mirror of
https://github.com/LostRuins/koboldcpp.git
synced 2025-09-11 01:24:36 +00:00
Merge branch 'upstream' into concedo_experimental
# Conflicts: # README.md # examples/llama.android/llama/src/main/cpp/llama-android.cpp # examples/run/run.cpp # examples/server/README.md # examples/server/bench/README.md # examples/server/tests/README.md # ggml/src/CMakeLists.txt # ggml/src/ggml-cpu/CMakeLists.txt # tests/test-backend-ops.cpp
This commit is contained in:
commit
911da8765f
46 changed files with 82848 additions and 73880 deletions
|
@ -20,6 +20,7 @@
|
||||||
#include <cstdarg>
|
#include <cstdarg>
|
||||||
#include <cstring>
|
#include <cstring>
|
||||||
#include <ctime>
|
#include <ctime>
|
||||||
|
#include <filesystem>
|
||||||
#include <fstream>
|
#include <fstream>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <iterator>
|
#include <iterator>
|
||||||
|
@ -64,7 +65,9 @@
|
||||||
#ifdef __linux__
|
#ifdef __linux__
|
||||||
#include <linux/limits.h>
|
#include <linux/limits.h>
|
||||||
#elif defined(_WIN32)
|
#elif defined(_WIN32)
|
||||||
#define PATH_MAX MAX_PATH
|
# if !defined(PATH_MAX)
|
||||||
|
# define PATH_MAX MAX_PATH
|
||||||
|
# endif
|
||||||
#else
|
#else
|
||||||
#include <sys/syslimits.h>
|
#include <sys/syslimits.h>
|
||||||
#endif
|
#endif
|
||||||
|
@ -1150,8 +1153,7 @@ static bool common_download_file(const std::string & url, const std::string & pa
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Check if the file already exists locally
|
// Check if the file already exists locally
|
||||||
struct stat model_file_info;
|
auto file_exists = std::filesystem::exists(path);
|
||||||
auto file_exists = (stat(path.c_str(), &model_file_info) == 0);
|
|
||||||
|
|
||||||
// If the file exists, check its JSON metadata companion file.
|
// If the file exists, check its JSON metadata companion file.
|
||||||
std::string metadata_path = path + ".json";
|
std::string metadata_path = path + ".json";
|
||||||
|
@ -1614,6 +1616,18 @@ std::string common_detokenize(llama_context * ctx, const std::vector<llama_token
|
||||||
// Chat template utils
|
// Chat template utils
|
||||||
//
|
//
|
||||||
|
|
||||||
|
std::string common_get_builtin_chat_template(const struct llama_model * model) {
|
||||||
|
static const char * template_key = "tokenizer.chat_template";
|
||||||
|
// call with NULL buffer to get the total size of the string
|
||||||
|
int32_t res = llama_model_meta_val_str(model, template_key, NULL, 0);
|
||||||
|
if (res > 0) {
|
||||||
|
std::vector<char> model_template(res + 1, 0);
|
||||||
|
llama_model_meta_val_str(model, template_key, model_template.data(), model_template.size());
|
||||||
|
return std::string(model_template.data(), model_template.size() - 1);
|
||||||
|
}
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
bool common_chat_verify_template(const std::string & tmpl) {
|
bool common_chat_verify_template(const std::string & tmpl) {
|
||||||
llama_chat_message chat[] = {{"user", "test"}};
|
llama_chat_message chat[] = {{"user", "test"}};
|
||||||
int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
|
int res = llama_chat_apply_template(nullptr, tmpl.c_str(), chat, 1, true, nullptr, 0);
|
||||||
|
|
|
@ -567,6 +567,9 @@ struct common_chat_msg {
|
||||||
std::string content;
|
std::string content;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Get the built-in chat template for the model. Return empty string if not present.
|
||||||
|
std::string common_get_builtin_chat_template(const struct llama_model * model);
|
||||||
|
|
||||||
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
|
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
|
||||||
bool common_chat_verify_template(const std::string & tmpl);
|
bool common_chat_verify_template(const std::string & tmpl);
|
||||||
|
|
||||||
|
|
|
@ -1764,25 +1764,19 @@ class DeciModel(Model):
|
||||||
self.gguf_writer.add_token_list(tokens)
|
self.gguf_writer.add_token_list(tokens)
|
||||||
self.gguf_writer.add_token_types(toktypes)
|
self.gguf_writer.add_token_types(toktypes)
|
||||||
|
|
||||||
special_vocab = gguf.SpecialVocab(
|
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
|
||||||
self.dir_model, load_merges=True,
|
|
||||||
special_token_types = ['bos', 'eos', 'eom', 'eot']
|
|
||||||
)
|
|
||||||
special_vocab._set_special_token("bos", 128000)
|
|
||||||
special_vocab._set_special_token("eos", 128001)
|
|
||||||
special_vocab._set_special_token("eom", 128008)
|
|
||||||
special_vocab._set_special_token("eot", 128009)
|
|
||||||
special_vocab.add_to_gguf(self.gguf_writer)
|
special_vocab.add_to_gguf(self.gguf_writer)
|
||||||
else:
|
else:
|
||||||
# DeciLM-7B
|
# DeciLM-7B
|
||||||
self._set_vocab_llama_hf()
|
self._set_vocab_llama_hf()
|
||||||
# self._set_vocab_gpt2()
|
|
||||||
|
|
||||||
def set_gguf_parameters(self):
|
def set_gguf_parameters(self):
|
||||||
if "block_configs" in self.hparams: # Llama-3_1-Nemotron-51B
|
if "block_configs" in self.hparams: # Llama-3_1-Nemotron-51B
|
||||||
assert self.block_count == len(self._num_kv_heads)
|
assert self.block_count == len(self._num_kv_heads)
|
||||||
assert self.block_count == len(self._num_heads)
|
assert self.block_count == len(self._num_heads)
|
||||||
assert self.block_count == len(self._ffn_dims)
|
assert self.block_count == len(self._ffn_dims)
|
||||||
|
if (rope_theta := self.hparams.get("rope_theta")) is not None:
|
||||||
|
self.gguf_writer.add_rope_freq_base(rope_theta)
|
||||||
self.gguf_writer.add_head_count_kv(self._num_kv_heads)
|
self.gguf_writer.add_head_count_kv(self._num_kv_heads)
|
||||||
self.gguf_writer.add_head_count(self._num_heads)
|
self.gguf_writer.add_head_count(self._num_heads)
|
||||||
self.gguf_writer.add_feed_forward_length(self._ffn_dims)
|
self.gguf_writer.add_feed_forward_length(self._ffn_dims)
|
||||||
|
|
|
@ -189,12 +189,12 @@ xychart-beta
|
||||||
"pp": {
|
"pp": {
|
||||||
"p95": round(data['metrics']["llamacpp_prompt_processing_second"]["p(95)"], 2),
|
"p95": round(data['metrics']["llamacpp_prompt_processing_second"]["p(95)"], 2),
|
||||||
"avg": round(data['metrics']["llamacpp_prompt_processing_second"]["avg"], 2),
|
"avg": round(data['metrics']["llamacpp_prompt_processing_second"]["avg"], 2),
|
||||||
"0": round(mean(prometheus_metrics['prompt_tokens_seconds']), 2),
|
"0": round(mean(prometheus_metrics['prompt_tokens_seconds']), 2) if 'prompt_tokens_seconds' in prometheus_metrics else 0,
|
||||||
},
|
},
|
||||||
"tg": {
|
"tg": {
|
||||||
"p95": round(data['metrics']["llamacpp_tokens_second"]["p(95)"], 2),
|
"p95": round(data['metrics']["llamacpp_tokens_second"]["p(95)"], 2),
|
||||||
"avg": round(data['metrics']["llamacpp_tokens_second"]["avg"], 2),
|
"avg": round(data['metrics']["llamacpp_tokens_second"]["avg"], 2),
|
||||||
"0": round(mean(prometheus_metrics['predicted_tokens_seconds']), 2),
|
"0": round(mean(prometheus_metrics['predicted_tokens_seconds']), 2) if 'predicted_tokens_seconds' in prometheus_metrics else 0,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
with open("results.github.env", 'a') as github_env:
|
with open("results.github.env", 'a') as github_env:
|
||||||
|
@ -214,11 +214,14 @@ def start_benchmark(args):
|
||||||
k6_args = [
|
k6_args = [
|
||||||
'run', args.scenario,
|
'run', args.scenario,
|
||||||
'--no-color',
|
'--no-color',
|
||||||
|
'--no-connection-reuse',
|
||||||
|
'--no-vu-connection-reuse',
|
||||||
]
|
]
|
||||||
k6_args.extend(['--duration', args.duration])
|
k6_args.extend(['--duration', args.duration])
|
||||||
k6_args.extend(['--iterations', args.n_prompts])
|
k6_args.extend(['--iterations', args.n_prompts])
|
||||||
k6_args.extend(['--vus', args.parallel])
|
k6_args.extend(['--vus', args.parallel])
|
||||||
k6_args.extend(['--summary-export', 'k6-results.json'])
|
k6_args.extend(['--summary-export', 'k6-results.json'])
|
||||||
|
k6_args.extend(['--out', 'csv=k6-results.csv'])
|
||||||
args = f"SERVER_BENCH_N_PROMPTS={args.n_prompts} SERVER_BENCH_MAX_PROMPT_TOKENS={args.max_prompt_tokens} SERVER_BENCH_MAX_CONTEXT={args.max_tokens} "
|
args = f"SERVER_BENCH_N_PROMPTS={args.n_prompts} SERVER_BENCH_MAX_PROMPT_TOKENS={args.max_prompt_tokens} SERVER_BENCH_MAX_CONTEXT={args.max_tokens} "
|
||||||
args = args + ' '.join([str(arg) for arg in [k6_path, *k6_args]])
|
args = args + ' '.join([str(arg) for arg in [k6_path, *k6_args]])
|
||||||
print(f"bench: starting k6 with: {args}")
|
print(f"bench: starting k6 with: {args}")
|
||||||
|
@ -231,7 +234,7 @@ def start_server(args):
|
||||||
server_process = start_server_background(args)
|
server_process = start_server_background(args)
|
||||||
|
|
||||||
attempts = 0
|
attempts = 0
|
||||||
max_attempts = 20
|
max_attempts = 600
|
||||||
if 'GITHUB_ACTIONS' in os.environ:
|
if 'GITHUB_ACTIONS' in os.environ:
|
||||||
max_attempts *= 2
|
max_attempts *= 2
|
||||||
|
|
||||||
|
@ -242,7 +245,15 @@ def start_server(args):
|
||||||
print(f"bench: waiting for server to start ...")
|
print(f"bench: waiting for server to start ...")
|
||||||
time.sleep(0.5)
|
time.sleep(0.5)
|
||||||
|
|
||||||
print("bench: server started.")
|
attempts = 0
|
||||||
|
while not is_server_ready(args.host, args.port):
|
||||||
|
attempts += 1
|
||||||
|
if attempts > max_attempts:
|
||||||
|
assert False, "server not ready"
|
||||||
|
print(f"bench: waiting for server to be ready ...")
|
||||||
|
time.sleep(0.5)
|
||||||
|
|
||||||
|
print("bench: server started and ready.")
|
||||||
return server_process
|
return server_process
|
||||||
|
|
||||||
|
|
||||||
|
@ -255,11 +266,6 @@ def start_server_background(args):
|
||||||
'--host', args.host,
|
'--host', args.host,
|
||||||
'--port', args.port,
|
'--port', args.port,
|
||||||
]
|
]
|
||||||
model_file = args.model_path_prefix + os.path.sep + args.hf_file
|
|
||||||
model_dir = os.path.dirname(model_file)
|
|
||||||
if not os.path.exists(model_dir):
|
|
||||||
os.makedirs(model_dir)
|
|
||||||
server_args.extend(['--model', model_file])
|
|
||||||
server_args.extend(['--hf-repo', args.hf_repo])
|
server_args.extend(['--hf-repo', args.hf_repo])
|
||||||
server_args.extend(['--hf-file', args.hf_file])
|
server_args.extend(['--hf-file', args.hf_file])
|
||||||
server_args.extend(['--n-gpu-layers', args.n_gpu_layers])
|
server_args.extend(['--n-gpu-layers', args.n_gpu_layers])
|
||||||
|
@ -303,6 +309,12 @@ def is_server_listening(server_fqdn, server_port):
|
||||||
return _is_server_listening
|
return _is_server_listening
|
||||||
|
|
||||||
|
|
||||||
|
def is_server_ready(server_fqdn, server_port):
|
||||||
|
url = f"http://{server_fqdn}:{server_port}/health"
|
||||||
|
response = requests.get(url)
|
||||||
|
return response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
def escape_metric_name(metric_name):
|
def escape_metric_name(metric_name):
|
||||||
return re.sub('[^A-Z0-9]', '_', metric_name.upper())
|
return re.sub('[^A-Z0-9]', '_', metric_name.upper())
|
||||||
|
|
||||||
|
|
|
@ -56,6 +56,7 @@ const llamacpp_completion_tokens = new Trend('llamacpp_completion_tokens')
|
||||||
|
|
||||||
const llamacpp_tokens_second = new Trend('llamacpp_tokens_second')
|
const llamacpp_tokens_second = new Trend('llamacpp_tokens_second')
|
||||||
const llamacpp_prompt_processing_second = new Trend('llamacpp_prompt_processing_second')
|
const llamacpp_prompt_processing_second = new Trend('llamacpp_prompt_processing_second')
|
||||||
|
const llamacpp_emit_first_token_second = new Trend('llamacpp_emit_first_token_second')
|
||||||
|
|
||||||
const llamacpp_prompt_tokens_total_counter = new Counter('llamacpp_prompt_tokens_total_counter')
|
const llamacpp_prompt_tokens_total_counter = new Counter('llamacpp_prompt_tokens_total_counter')
|
||||||
const llamacpp_completion_tokens_total_counter = new Counter('llamacpp_completion_tokens_total_counter')
|
const llamacpp_completion_tokens_total_counter = new Counter('llamacpp_completion_tokens_total_counter')
|
||||||
|
@ -89,6 +90,9 @@ export default function () {
|
||||||
],
|
],
|
||||||
"model": model,
|
"model": model,
|
||||||
"stream": true,
|
"stream": true,
|
||||||
|
"stream_options": {
|
||||||
|
"include_usage": true, // False to be supported in llama.cpp server
|
||||||
|
},
|
||||||
"seed": 42,
|
"seed": 42,
|
||||||
"max_tokens": max_tokens,
|
"max_tokens": max_tokens,
|
||||||
"stop": ["<|im_end|>"] // This is temporary for phi-2 base (i.e. not instructed) since the server expects that the model always to emit BOS
|
"stop": ["<|im_end|>"] // This is temporary for phi-2 base (i.e. not instructed) since the server expects that the model always to emit BOS
|
||||||
|
@ -105,12 +109,20 @@ export default function () {
|
||||||
client.on('event', function (event) {
|
client.on('event', function (event) {
|
||||||
if (promptEvalEndTime == null) {
|
if (promptEvalEndTime == null) {
|
||||||
promptEvalEndTime = new Date()
|
promptEvalEndTime = new Date()
|
||||||
|
llamacpp_emit_first_token_second.add((promptEvalEndTime - startTime) / 1.e3)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (event.data === '[DONE]' || event.data === '') {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
let chunk = JSON.parse(event.data)
|
let chunk = JSON.parse(event.data)
|
||||||
let choice = chunk.choices[0]
|
|
||||||
if (choice.finish_reason) {
|
if (chunk.choices && chunk.choices.length > 0) {
|
||||||
finish_reason = choice.finish_reason
|
let choice = chunk.choices[0]
|
||||||
|
if (choice.finish_reason) {
|
||||||
|
finish_reason = choice.finish_reason
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (chunk.usage) {
|
if (chunk.usage) {
|
||||||
|
|
|
@ -67,6 +67,13 @@ enum server_task_type {
|
||||||
SERVER_TASK_TYPE_SET_LORA,
|
SERVER_TASK_TYPE_SET_LORA,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
enum oaicompat_type {
|
||||||
|
OAICOMPAT_TYPE_NONE,
|
||||||
|
OAICOMPAT_TYPE_CHAT,
|
||||||
|
OAICOMPAT_TYPE_COMPLETION,
|
||||||
|
OAICOMPAT_TYPE_EMBEDDING,
|
||||||
|
};
|
||||||
|
|
||||||
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
|
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
|
||||||
enum error_type {
|
enum error_type {
|
||||||
ERROR_TYPE_INVALID_REQUEST,
|
ERROR_TYPE_INVALID_REQUEST,
|
||||||
|
@ -91,6 +98,8 @@ struct slot_params {
|
||||||
int64_t t_max_prompt_ms = -1; // TODO: implement
|
int64_t t_max_prompt_ms = -1; // TODO: implement
|
||||||
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
||||||
|
|
||||||
|
std::vector<common_lora_adapter_container> lora;
|
||||||
|
|
||||||
std::vector<std::string> antiprompt;
|
std::vector<std::string> antiprompt;
|
||||||
std::vector<std::string> response_fields;
|
std::vector<std::string> response_fields;
|
||||||
bool timings_per_token = false;
|
bool timings_per_token = false;
|
||||||
|
@ -101,11 +110,10 @@ struct slot_params {
|
||||||
struct common_params_speculative speculative;
|
struct common_params_speculative speculative;
|
||||||
|
|
||||||
// OAI-compat fields
|
// OAI-compat fields
|
||||||
bool verbose = false;
|
bool verbose = false;
|
||||||
bool oaicompat = false;
|
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||||
bool oaicompat_chat = true;
|
std::string oaicompat_model;
|
||||||
std::string oaicompat_model;
|
std::string oaicompat_cmpl_id;
|
||||||
std::string oaicompat_cmpl_id;
|
|
||||||
|
|
||||||
json to_json() const {
|
json to_json() const {
|
||||||
std::vector<std::string> samplers;
|
std::vector<std::string> samplers;
|
||||||
|
@ -114,6 +122,11 @@ struct slot_params {
|
||||||
samplers.emplace_back(common_sampler_type_to_str(sampler));
|
samplers.emplace_back(common_sampler_type_to_str(sampler));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
json lora = json::array();
|
||||||
|
for (size_t i = 0; i < this->lora.size(); ++i) {
|
||||||
|
lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
|
||||||
|
}
|
||||||
|
|
||||||
return json {
|
return json {
|
||||||
{"n_predict", n_predict}, // Server configured n_predict
|
{"n_predict", n_predict}, // Server configured n_predict
|
||||||
{"seed", sampling.seed},
|
{"seed", sampling.seed},
|
||||||
|
@ -154,6 +167,7 @@ struct slot_params {
|
||||||
{"speculative.p_min", speculative.p_min},
|
{"speculative.p_min", speculative.p_min},
|
||||||
{"timings_per_token", timings_per_token},
|
{"timings_per_token", timings_per_token},
|
||||||
{"post_sampling_probs", post_sampling_probs},
|
{"post_sampling_probs", post_sampling_probs},
|
||||||
|
{"lora", lora},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -183,12 +197,16 @@ struct server_task {
|
||||||
// used by SERVER_TASK_TYPE_METRICS
|
// used by SERVER_TASK_TYPE_METRICS
|
||||||
bool metrics_reset_bucket = false;
|
bool metrics_reset_bucket = false;
|
||||||
|
|
||||||
|
// used by SERVER_TASK_TYPE_SET_LORA
|
||||||
|
std::vector<common_lora_adapter_container> set_lora;
|
||||||
|
|
||||||
server_task(server_task_type type) : type(type) {}
|
server_task(server_task_type type) : type(type) {}
|
||||||
|
|
||||||
static slot_params params_from_json_cmpl(
|
static slot_params params_from_json_cmpl(
|
||||||
const llama_model * model,
|
const llama_model * model,
|
||||||
const llama_context * ctx,
|
const llama_context * ctx,
|
||||||
const common_params & params_base,
|
const common_params & params_base,
|
||||||
|
const std::vector<common_lora_adapter_container> & lora_base,
|
||||||
const json & data) {
|
const json & data) {
|
||||||
slot_params params;
|
slot_params params;
|
||||||
|
|
||||||
|
@ -245,6 +263,16 @@ struct server_task {
|
||||||
params.speculative.n_min = std::max(params.speculative.n_min, 2);
|
params.speculative.n_min = std::max(params.speculative.n_min, 2);
|
||||||
params.speculative.n_max = std::max(params.speculative.n_max, 0);
|
params.speculative.n_max = std::max(params.speculative.n_max, 0);
|
||||||
|
|
||||||
|
if (data.contains("lora")) {
|
||||||
|
if (data.at("lora").is_array()) {
|
||||||
|
params.lora = parse_lora_request(lora_base, data.at("lora"));
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
params.lora = lora_base;
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: add more sanity checks for the input parameters
|
// TODO: add more sanity checks for the input parameters
|
||||||
|
|
||||||
if (params.sampling.penalty_last_n < -1) {
|
if (params.sampling.penalty_last_n < -1) {
|
||||||
|
@ -529,11 +557,10 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||||
slot_params generation_params;
|
slot_params generation_params;
|
||||||
|
|
||||||
// OAI-compat fields
|
// OAI-compat fields
|
||||||
bool verbose = false;
|
bool verbose = false;
|
||||||
bool oaicompat = false;
|
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||||
bool oaicompat_chat = true; // TODO: support oaicompat for non-chat
|
std::string oaicompat_model;
|
||||||
std::string oaicompat_model;
|
std::string oaicompat_cmpl_id;
|
||||||
std::string oaicompat_cmpl_id;
|
|
||||||
|
|
||||||
virtual int get_index() override {
|
virtual int get_index() override {
|
||||||
return index;
|
return index;
|
||||||
|
@ -544,9 +571,16 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual json to_json() override {
|
virtual json to_json() override {
|
||||||
return oaicompat
|
switch (oaicompat) {
|
||||||
? (stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat())
|
case OAICOMPAT_TYPE_NONE:
|
||||||
: to_json_non_oaicompat();
|
return to_json_non_oaicompat();
|
||||||
|
case OAICOMPAT_TYPE_COMPLETION:
|
||||||
|
return to_json_oaicompat();
|
||||||
|
case OAICOMPAT_TYPE_CHAT:
|
||||||
|
return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat();
|
||||||
|
default:
|
||||||
|
GGML_ASSERT(false && "Invalid oaicompat_type");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
json to_json_non_oaicompat() {
|
json to_json_non_oaicompat() {
|
||||||
|
@ -574,6 +608,50 @@ struct server_task_result_cmpl_final : server_task_result {
|
||||||
return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
|
return response_fields.empty() ? res : json_get_nested_values(response_fields, res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
json to_json_oaicompat() {
|
||||||
|
std::time_t t = std::time(0);
|
||||||
|
json logprobs = json(nullptr); // OAI default to null
|
||||||
|
if (!stream && probs_output.size() > 0) {
|
||||||
|
logprobs = json{
|
||||||
|
{"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
json finish_reason = "length";
|
||||||
|
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||||
|
finish_reason = "stop";
|
||||||
|
}
|
||||||
|
json res = json {
|
||||||
|
{"choices", json::array({
|
||||||
|
json{
|
||||||
|
{"text", stream ? "" : content}, // in stream mode, content is already in last partial chunk
|
||||||
|
{"index", index},
|
||||||
|
{"logprobs", logprobs},
|
||||||
|
{"finish_reason", finish_reason},
|
||||||
|
}
|
||||||
|
})},
|
||||||
|
{"created", t},
|
||||||
|
{"model", oaicompat_model},
|
||||||
|
{"system_fingerprint", build_info},
|
||||||
|
{"object", "text_completion"},
|
||||||
|
{"usage", json {
|
||||||
|
{"completion_tokens", n_decoded},
|
||||||
|
{"prompt_tokens", n_prompt_tokens},
|
||||||
|
{"total_tokens", n_decoded + n_prompt_tokens}
|
||||||
|
}},
|
||||||
|
{"id", oaicompat_cmpl_id}
|
||||||
|
};
|
||||||
|
|
||||||
|
// extra fields for debugging purposes
|
||||||
|
if (verbose) {
|
||||||
|
res["__verbose"] = to_json_non_oaicompat();
|
||||||
|
}
|
||||||
|
if (timings.prompt_n >= 0) {
|
||||||
|
res.push_back({"timings", timings.to_json()});
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
json to_json_oaicompat_chat() {
|
json to_json_oaicompat_chat() {
|
||||||
std::string finish_reason = "length";
|
std::string finish_reason = "length";
|
||||||
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
|
||||||
|
@ -671,11 +749,10 @@ struct server_task_result_cmpl_partial : server_task_result {
|
||||||
result_timings timings;
|
result_timings timings;
|
||||||
|
|
||||||
// OAI-compat fields
|
// OAI-compat fields
|
||||||
bool verbose = false;
|
bool verbose = false;
|
||||||
bool oaicompat = false;
|
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||||
bool oaicompat_chat = true; // TODO: support oaicompat for non-chat
|
std::string oaicompat_model;
|
||||||
std::string oaicompat_model;
|
std::string oaicompat_cmpl_id;
|
||||||
std::string oaicompat_cmpl_id;
|
|
||||||
|
|
||||||
virtual int get_index() override {
|
virtual int get_index() override {
|
||||||
return index;
|
return index;
|
||||||
|
@ -686,7 +763,16 @@ struct server_task_result_cmpl_partial : server_task_result {
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual json to_json() override {
|
virtual json to_json() override {
|
||||||
return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
|
switch (oaicompat) {
|
||||||
|
case OAICOMPAT_TYPE_NONE:
|
||||||
|
return to_json_non_oaicompat();
|
||||||
|
case OAICOMPAT_TYPE_COMPLETION:
|
||||||
|
return to_json_oaicompat();
|
||||||
|
case OAICOMPAT_TYPE_CHAT:
|
||||||
|
return to_json_oaicompat_chat();
|
||||||
|
default:
|
||||||
|
GGML_ASSERT(false && "Invalid oaicompat_type");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
json to_json_non_oaicompat() {
|
json to_json_non_oaicompat() {
|
||||||
|
@ -711,6 +797,41 @@ struct server_task_result_cmpl_partial : server_task_result {
|
||||||
}
|
}
|
||||||
|
|
||||||
json to_json_oaicompat() {
|
json to_json_oaicompat() {
|
||||||
|
std::time_t t = std::time(0);
|
||||||
|
json logprobs = json(nullptr); // OAI default to null
|
||||||
|
if (prob_output.probs.size() > 0) {
|
||||||
|
logprobs = json{
|
||||||
|
{"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
json res = json {
|
||||||
|
{"choices", json::array({
|
||||||
|
json{
|
||||||
|
{"text", content},
|
||||||
|
{"index", index},
|
||||||
|
{"logprobs", logprobs},
|
||||||
|
{"finish_reason", nullptr},
|
||||||
|
}
|
||||||
|
})},
|
||||||
|
{"created", t},
|
||||||
|
{"model", oaicompat_model},
|
||||||
|
{"system_fingerprint", build_info},
|
||||||
|
{"object", "text_completion"},
|
||||||
|
{"id", oaicompat_cmpl_id}
|
||||||
|
};
|
||||||
|
|
||||||
|
// extra fields for debugging purposes
|
||||||
|
if (verbose) {
|
||||||
|
res["__verbose"] = to_json_non_oaicompat();
|
||||||
|
}
|
||||||
|
if (timings.prompt_n >= 0) {
|
||||||
|
res.push_back({"timings", timings.to_json()});
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
json to_json_oaicompat_chat() {
|
||||||
bool first = n_decoded == 0;
|
bool first = n_decoded == 0;
|
||||||
std::time_t t = std::time(0);
|
std::time_t t = std::time(0);
|
||||||
json choices;
|
json choices;
|
||||||
|
@ -789,14 +910,16 @@ struct server_task_result_embd : server_task_result {
|
||||||
int32_t n_tokens;
|
int32_t n_tokens;
|
||||||
|
|
||||||
// OAI-compat fields
|
// OAI-compat fields
|
||||||
bool oaicompat = false;
|
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
|
||||||
|
|
||||||
virtual int get_index() override {
|
virtual int get_index() override {
|
||||||
return index;
|
return index;
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual json to_json() override {
|
virtual json to_json() override {
|
||||||
return oaicompat ? to_json_oaicompat() : to_json_non_oaicompat();
|
return oaicompat == OAICOMPAT_TYPE_EMBEDDING
|
||||||
|
? to_json_oaicompat()
|
||||||
|
: to_json_non_oaicompat();
|
||||||
}
|
}
|
||||||
|
|
||||||
json to_json_non_oaicompat() {
|
json to_json_non_oaicompat() {
|
||||||
|
@ -1009,6 +1132,8 @@ struct server_slot {
|
||||||
|
|
||||||
common_speculative * spec = nullptr;
|
common_speculative * spec = nullptr;
|
||||||
|
|
||||||
|
std::vector<common_lora_adapter_container> lora;
|
||||||
|
|
||||||
// the index relative to completion multi-task request
|
// the index relative to completion multi-task request
|
||||||
size_t index = 0;
|
size_t index = 0;
|
||||||
|
|
||||||
|
@ -1090,6 +1215,11 @@ struct server_slot {
|
||||||
return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
|
return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool can_batch_with(server_slot & other_slot) {
|
||||||
|
return is_non_causal() == other_slot.is_non_causal()
|
||||||
|
&& are_lora_equal(lora, other_slot.lora);
|
||||||
|
}
|
||||||
|
|
||||||
bool has_budget(const common_params & global_params) {
|
bool has_budget(const common_params & global_params) {
|
||||||
if (params.n_predict == -1 && global_params.n_predict == -1) {
|
if (params.n_predict == -1 && global_params.n_predict == -1) {
|
||||||
return true; // limitless
|
return true; // limitless
|
||||||
|
@ -1499,7 +1629,7 @@ struct server_context {
|
||||||
|
|
||||||
llama_model * model = nullptr;
|
llama_model * model = nullptr;
|
||||||
llama_context * ctx = nullptr;
|
llama_context * ctx = nullptr;
|
||||||
std::vector<common_lora_adapter_container> loras;
|
std::vector<common_lora_adapter_container> lora;
|
||||||
|
|
||||||
llama_model * model_dft = nullptr;
|
llama_model * model_dft = nullptr;
|
||||||
llama_context_params cparams_dft;
|
llama_context_params cparams_dft;
|
||||||
|
@ -1566,7 +1696,7 @@ struct server_context {
|
||||||
|
|
||||||
model = llama_init.model;
|
model = llama_init.model;
|
||||||
ctx = llama_init.context;
|
ctx = llama_init.context;
|
||||||
loras = llama_init.lora_adapters;
|
lora = llama_init.lora_adapters;
|
||||||
|
|
||||||
if (model == nullptr) {
|
if (model == nullptr) {
|
||||||
SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
|
SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
|
||||||
|
@ -1623,17 +1753,10 @@ struct server_context {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool validate_model_chat_template() const {
|
bool validate_builtin_chat_template() const {
|
||||||
std::vector<char> model_template(2048, 0); // longest known template is about 1200 bytes
|
llama_chat_message chat[] = {{"user", "test"}};
|
||||||
std::string template_key = "tokenizer.chat_template";
|
int32_t chat_res = llama_chat_apply_template(model, nullptr, chat, 1, true, nullptr, 0);
|
||||||
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
|
return chat_res > 0;
|
||||||
if (res >= 0) {
|
|
||||||
llama_chat_message chat[] = {{"user", "test"}};
|
|
||||||
std::string tmpl = std::string(model_template.data(), model_template.size());
|
|
||||||
int32_t chat_res = llama_chat_apply_template(model, tmpl.c_str(), chat, 1, true, nullptr, 0);
|
|
||||||
return chat_res > 0;
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void init() {
|
void init() {
|
||||||
|
@ -1772,6 +1895,12 @@ struct server_context {
|
||||||
slot.params = std::move(task.params);
|
slot.params = std::move(task.params);
|
||||||
slot.prompt_tokens = std::move(task.prompt_tokens);
|
slot.prompt_tokens = std::move(task.prompt_tokens);
|
||||||
|
|
||||||
|
if (!are_lora_equal(task.params.lora, slot.lora)) {
|
||||||
|
// if lora is changed, we cannot reuse cached tokens
|
||||||
|
slot.cache_tokens.clear();
|
||||||
|
slot.lora = std::move(task.params.lora);
|
||||||
|
}
|
||||||
|
|
||||||
SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
|
SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
|
||||||
|
|
||||||
if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
|
if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
|
||||||
|
@ -1856,6 +1985,8 @@ struct server_context {
|
||||||
result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
|
result.text_to_send = slot.generated_text.substr(pos, std::string::npos);
|
||||||
slot.n_sent_text += result.text_to_send.size();
|
slot.n_sent_text += result.text_to_send.size();
|
||||||
// add the token to slot queue and cache
|
// add the token to slot queue and cache
|
||||||
|
} else {
|
||||||
|
result.text_to_send = "";
|
||||||
}
|
}
|
||||||
|
|
||||||
slot.add_token(result);
|
slot.add_token(result);
|
||||||
|
@ -2042,7 +2173,6 @@ struct server_context {
|
||||||
|
|
||||||
res->verbose = slot.params.verbose;
|
res->verbose = slot.params.verbose;
|
||||||
res->oaicompat = slot.params.oaicompat;
|
res->oaicompat = slot.params.oaicompat;
|
||||||
res->oaicompat_chat = slot.params.oaicompat_chat;
|
|
||||||
res->oaicompat_model = slot.params.oaicompat_model;
|
res->oaicompat_model = slot.params.oaicompat_model;
|
||||||
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
||||||
|
|
||||||
|
@ -2083,7 +2213,6 @@ struct server_context {
|
||||||
res->verbose = slot.params.verbose;
|
res->verbose = slot.params.verbose;
|
||||||
res->stream = slot.params.stream;
|
res->stream = slot.params.stream;
|
||||||
res->oaicompat = slot.params.oaicompat;
|
res->oaicompat = slot.params.oaicompat;
|
||||||
res->oaicompat_chat = slot.params.oaicompat_chat;
|
|
||||||
res->oaicompat_model = slot.params.oaicompat_model;
|
res->oaicompat_model = slot.params.oaicompat_model;
|
||||||
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
res->oaicompat_cmpl_id = slot.params.oaicompat_cmpl_id;
|
||||||
|
|
||||||
|
@ -2463,7 +2592,7 @@ struct server_context {
|
||||||
} break;
|
} break;
|
||||||
case SERVER_TASK_TYPE_SET_LORA:
|
case SERVER_TASK_TYPE_SET_LORA:
|
||||||
{
|
{
|
||||||
common_lora_adapters_apply(ctx, loras);
|
lora = std::move(task.set_lora);
|
||||||
auto res = std::make_unique<server_task_result_apply_lora>();
|
auto res = std::make_unique<server_task_result_apply_lora>();
|
||||||
res->id = task.id;
|
res->id = task.id;
|
||||||
queue_results.send(std::move(res));
|
queue_results.send(std::move(res));
|
||||||
|
@ -2540,12 +2669,22 @@ struct server_context {
|
||||||
// start populating the batch for this iteration
|
// start populating the batch for this iteration
|
||||||
common_batch_clear(batch);
|
common_batch_clear(batch);
|
||||||
|
|
||||||
|
// track if given slot can be batched with slots already in the batch
|
||||||
|
server_slot * slot_batched = nullptr;
|
||||||
|
|
||||||
// frist, add sampled tokens from any ongoing sequences
|
// frist, add sampled tokens from any ongoing sequences
|
||||||
for (auto & slot : slots) {
|
for (auto & slot : slots) {
|
||||||
if (slot.state != SLOT_STATE_GENERATING) {
|
if (slot.state != SLOT_STATE_GENERATING) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check if we can batch this slot with the previous one
|
||||||
|
if (!slot_batched) {
|
||||||
|
slot_batched = &slot;
|
||||||
|
} else if (!slot_batched->can_batch_with(slot)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
slot.i_batch = batch.n_tokens;
|
slot.i_batch = batch.n_tokens;
|
||||||
|
|
||||||
common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
|
common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
|
||||||
|
@ -2564,15 +2703,18 @@ struct server_context {
|
||||||
int32_t n_batch = llama_n_batch(ctx);
|
int32_t n_batch = llama_n_batch(ctx);
|
||||||
int32_t n_ubatch = llama_n_ubatch(ctx);
|
int32_t n_ubatch = llama_n_ubatch(ctx);
|
||||||
|
|
||||||
// track if this is an embedding or non-embedding batch
|
|
||||||
// if we've added sampled tokens above, we are in non-embedding mode
|
|
||||||
// -1: none, 0: non-embedding, 1: embedding
|
|
||||||
// TODO: make enum
|
|
||||||
int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
|
|
||||||
|
|
||||||
// next, batch any pending prompts without exceeding n_batch
|
// next, batch any pending prompts without exceeding n_batch
|
||||||
if (params_base.cont_batching || batch.n_tokens == 0) {
|
if (params_base.cont_batching || batch.n_tokens == 0) {
|
||||||
for (auto & slot : slots) {
|
for (auto & slot : slots) {
|
||||||
|
// check if we can batch this slot with the previous one
|
||||||
|
if (slot.is_processing()) {
|
||||||
|
if (!slot_batched) {
|
||||||
|
slot_batched = &slot;
|
||||||
|
} else if (!slot_batched->can_batch_with(slot)) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// this slot still has a prompt to be processed
|
// this slot still has a prompt to be processed
|
||||||
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
|
if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
|
||||||
auto & prompt_tokens = slot.prompt_tokens;
|
auto & prompt_tokens = slot.prompt_tokens;
|
||||||
|
@ -2733,14 +2875,6 @@ struct server_context {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// check that we are in the right batch_type, if not defer the slot
|
|
||||||
int slot_type = slot.is_non_causal();
|
|
||||||
if (batch_type == -1) {
|
|
||||||
batch_type = slot_type;
|
|
||||||
} else if (batch_type != slot_type) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// keep only the common part
|
// keep only the common part
|
||||||
if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
|
if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
|
||||||
// could not partially delete (likely using a non-Transformer model)
|
// could not partially delete (likely using a non-Transformer model)
|
||||||
|
@ -2808,8 +2942,12 @@ struct server_context {
|
||||||
|
|
||||||
SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
|
SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
|
||||||
|
|
||||||
// make sure we're in the right embedding mode
|
if (slot_batched) {
|
||||||
llama_set_embeddings(ctx, batch_type == 1);
|
// make sure we're in the right embedding mode
|
||||||
|
llama_set_embeddings(ctx, slot_batched->is_non_causal());
|
||||||
|
// apply lora, only need to do it once per batch
|
||||||
|
common_lora_adapters_apply(ctx, slot_batched->lora);
|
||||||
|
}
|
||||||
|
|
||||||
// process the created batch of tokens
|
// process the created batch of tokens
|
||||||
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
|
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
|
||||||
|
@ -3482,7 +3620,7 @@ int main(int argc, char ** argv) {
|
||||||
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
|
{ "default_generation_settings", ctx_server.default_generation_settings_for_props },
|
||||||
{ "total_slots", ctx_server.params_base.n_parallel },
|
{ "total_slots", ctx_server.params_base.n_parallel },
|
||||||
{ "model_path", ctx_server.params_base.model },
|
{ "model_path", ctx_server.params_base.model },
|
||||||
{ "chat_template", llama_get_chat_template(ctx_server.model) },
|
{ "chat_template", common_get_builtin_chat_template(ctx_server.model) },
|
||||||
{ "build_info", build_info },
|
{ "build_info", build_info },
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -3504,12 +3642,11 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
// handle completion-like requests (completion, chat, infill)
|
// handle completion-like requests (completion, chat, infill)
|
||||||
// we can optionally provide a custom format for partial results and final results
|
// we can optionally provide a custom format for partial results and final results
|
||||||
const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](
|
const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok](
|
||||||
server_task_type type,
|
server_task_type type,
|
||||||
json & data,
|
json & data,
|
||||||
httplib::Response & res,
|
httplib::Response & res,
|
||||||
bool oaicompat = false,
|
oaicompat_type oaicompat) {
|
||||||
bool oaicompat_chat = false) {
|
|
||||||
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
|
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
|
||||||
|
|
||||||
if (ctx_server.params_base.embedding) {
|
if (ctx_server.params_base.embedding) {
|
||||||
|
@ -3530,13 +3667,17 @@ int main(int argc, char ** argv) {
|
||||||
task.index = i;
|
task.index = i;
|
||||||
|
|
||||||
task.prompt_tokens = std::move(tokenized_prompts[i]);
|
task.prompt_tokens = std::move(tokenized_prompts[i]);
|
||||||
task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.ctx, ctx_server.params_base, data);
|
task.params = server_task::params_from_json_cmpl(
|
||||||
|
ctx_server.model,
|
||||||
|
ctx_server.ctx,
|
||||||
|
ctx_server.params_base,
|
||||||
|
ctx_server.lora,
|
||||||
|
data);
|
||||||
task.id_selected_slot = json_value(data, "id_slot", -1);
|
task.id_selected_slot = json_value(data, "id_slot", -1);
|
||||||
|
|
||||||
// OAI-compat
|
// OAI-compat
|
||||||
task.params.oaicompat = oaicompat;
|
task.params.oaicompat = oaicompat;
|
||||||
task.params.oaicompat_chat = oaicompat_chat;
|
task.params.oaicompat_cmpl_id = completion_id;
|
||||||
task.params.oaicompat_cmpl_id = completion_id;
|
|
||||||
// oaicompat_model is already populated by params_from_json_cmpl
|
// oaicompat_model is already populated by params_from_json_cmpl
|
||||||
|
|
||||||
tasks.push_back(task);
|
tasks.push_back(task);
|
||||||
|
@ -3587,7 +3728,7 @@ int main(int argc, char ** argv) {
|
||||||
}, [&](const json & error_data) {
|
}, [&](const json & error_data) {
|
||||||
server_sent_event(sink, "error", error_data);
|
server_sent_event(sink, "error", error_data);
|
||||||
});
|
});
|
||||||
if (oaicompat) {
|
if (oaicompat != OAICOMPAT_TYPE_NONE) {
|
||||||
static const std::string ev_done = "data: [DONE]\n\n";
|
static const std::string ev_done = "data: [DONE]\n\n";
|
||||||
sink.write(ev_done.data(), ev_done.size());
|
sink.write(ev_done.data(), ev_done.size());
|
||||||
}
|
}
|
||||||
|
@ -3603,17 +3744,25 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
const auto handle_completions = [&handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
|
const auto handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
|
||||||
json data = json::parse(req.body);
|
json data = json::parse(req.body);
|
||||||
return handle_completions_generic(
|
return handle_completions_impl(
|
||||||
SERVER_TASK_TYPE_COMPLETION,
|
SERVER_TASK_TYPE_COMPLETION,
|
||||||
data,
|
data,
|
||||||
res,
|
res,
|
||||||
/* oaicompat */ false,
|
OAICOMPAT_TYPE_NONE);
|
||||||
/* oaicompat_chat */ false);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
|
const auto handle_completions_oai = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
|
||||||
|
json data = oaicompat_completion_params_parse(json::parse(req.body));
|
||||||
|
return handle_completions_impl(
|
||||||
|
SERVER_TASK_TYPE_COMPLETION,
|
||||||
|
data,
|
||||||
|
res,
|
||||||
|
OAICOMPAT_TYPE_COMPLETION);
|
||||||
|
};
|
||||||
|
|
||||||
|
const auto handle_infill = [&ctx_server, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
|
||||||
// check model compatibility
|
// check model compatibility
|
||||||
std::string err;
|
std::string err;
|
||||||
if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
|
if (llama_token_fim_pre(ctx_server.model) == LLAMA_TOKEN_NULL) {
|
||||||
|
@ -3682,22 +3831,25 @@ int main(int argc, char ** argv) {
|
||||||
tokenized_prompts[0]
|
tokenized_prompts[0]
|
||||||
);
|
);
|
||||||
|
|
||||||
return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
|
return handle_completions_impl(
|
||||||
|
SERVER_TASK_TYPE_INFILL,
|
||||||
|
data,
|
||||||
|
res,
|
||||||
|
OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
|
||||||
};
|
};
|
||||||
|
|
||||||
const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &handle_completions_generic](const httplib::Request & req, httplib::Response & res) {
|
const auto handle_chat_completions = [&ctx_server, ¶ms, &res_error, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
|
||||||
if (ctx_server.params_base.embedding) {
|
if (ctx_server.params_base.embedding) {
|
||||||
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
|
json data = oaicompat_chat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
|
||||||
return handle_completions_generic(
|
return handle_completions_impl(
|
||||||
SERVER_TASK_TYPE_COMPLETION,
|
SERVER_TASK_TYPE_COMPLETION,
|
||||||
data,
|
data,
|
||||||
res,
|
res,
|
||||||
/* oaicompat */ true,
|
OAICOMPAT_TYPE_CHAT);
|
||||||
/* oaicompat_chat */ true);
|
|
||||||
};
|
};
|
||||||
|
|
||||||
const auto handle_models = [¶ms, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
|
const auto handle_models = [¶ms, &ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) {
|
||||||
|
@ -3770,10 +3922,10 @@ int main(int argc, char ** argv) {
|
||||||
res_ok(res, data);
|
res_ok(res, data);
|
||||||
};
|
};
|
||||||
|
|
||||||
const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, bool oaicompat) {
|
const auto handle_embeddings_impl = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) {
|
||||||
const json body = json::parse(req.body);
|
const json body = json::parse(req.body);
|
||||||
|
|
||||||
if (oaicompat && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
|
if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
|
||||||
res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
|
res_error(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -3783,7 +3935,7 @@ int main(int argc, char ** argv) {
|
||||||
if (body.count("input") != 0) {
|
if (body.count("input") != 0) {
|
||||||
prompt = body.at("input");
|
prompt = body.at("input");
|
||||||
} else if (body.contains("content")) {
|
} else if (body.contains("content")) {
|
||||||
oaicompat = false;
|
oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible
|
||||||
prompt = body.at("content");
|
prompt = body.at("content");
|
||||||
} else {
|
} else {
|
||||||
res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
|
||||||
|
@ -3852,16 +4004,18 @@ int main(int argc, char ** argv) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// write JSON response
|
// write JSON response
|
||||||
json root = oaicompat ? format_embeddings_response_oaicompat(body, responses, use_base64) : json(responses);
|
json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING
|
||||||
|
? format_embeddings_response_oaicompat(body, responses, use_base64)
|
||||||
|
: json(responses);
|
||||||
res_ok(res, root);
|
res_ok(res, root);
|
||||||
};
|
};
|
||||||
|
|
||||||
const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
|
const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
|
||||||
handle_embeddings_impl(req, res, false);
|
handle_embeddings_impl(req, res, OAICOMPAT_TYPE_NONE);
|
||||||
};
|
};
|
||||||
|
|
||||||
const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
|
const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
|
||||||
handle_embeddings_impl(req, res, true);
|
handle_embeddings_impl(req, res, OAICOMPAT_TYPE_EMBEDDING);
|
||||||
};
|
};
|
||||||
|
|
||||||
const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
|
const auto handle_rerank = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) {
|
||||||
|
@ -3944,8 +4098,8 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
|
const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
|
||||||
json result = json::array();
|
json result = json::array();
|
||||||
for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
|
for (size_t i = 0; i < ctx_server.lora.size(); ++i) {
|
||||||
auto & lora = ctx_server.loras[i];
|
auto & lora = ctx_server.lora[i];
|
||||||
result.push_back({
|
result.push_back({
|
||||||
{"id", i},
|
{"id", i},
|
||||||
{"path", lora.path},
|
{"path", lora.path},
|
||||||
|
@ -3957,27 +4111,14 @@ int main(int argc, char ** argv) {
|
||||||
};
|
};
|
||||||
|
|
||||||
const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
|
const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
|
||||||
const std::vector<json> body = json::parse(req.body);
|
const json body = json::parse(req.body);
|
||||||
int max_idx = ctx_server.loras.size();
|
if (!body.is_array()) {
|
||||||
|
res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
|
||||||
// clear existing value
|
return;
|
||||||
for (auto & lora : ctx_server.loras) {
|
|
||||||
lora.scale = 0.0f;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// set value
|
|
||||||
for (auto entry : body) {
|
|
||||||
int id = entry.at("id");
|
|
||||||
float scale = entry.at("scale");
|
|
||||||
if (0 <= id && id < max_idx) {
|
|
||||||
ctx_server.loras[id].scale = scale;
|
|
||||||
} else {
|
|
||||||
throw std::runtime_error("invalid adapter id");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
server_task task(SERVER_TASK_TYPE_SET_LORA);
|
server_task task(SERVER_TASK_TYPE_SET_LORA);
|
||||||
task.id = ctx_server.queue_tasks.get_new_id();
|
task.id = ctx_server.queue_tasks.get_new_id();
|
||||||
|
task.set_lora = parse_lora_request(ctx_server.lora, body);
|
||||||
ctx_server.queue_results.add_waiting_task_id(task.id);
|
ctx_server.queue_results.add_waiting_task_id(task.id);
|
||||||
ctx_server.queue_tasks.post(task);
|
ctx_server.queue_tasks.post(task);
|
||||||
|
|
||||||
|
@ -4031,7 +4172,7 @@ int main(int argc, char ** argv) {
|
||||||
svr->Get ("/v1/models", handle_models); // public endpoint (no API key check)
|
svr->Get ("/v1/models", handle_models); // public endpoint (no API key check)
|
||||||
svr->Post("/completion", handle_completions); // legacy
|
svr->Post("/completion", handle_completions); // legacy
|
||||||
svr->Post("/completions", handle_completions);
|
svr->Post("/completions", handle_completions);
|
||||||
svr->Post("/v1/completions", handle_completions);
|
svr->Post("/v1/completions", handle_completions_oai);
|
||||||
svr->Post("/chat/completions", handle_chat_completions);
|
svr->Post("/chat/completions", handle_chat_completions);
|
||||||
svr->Post("/v1/chat/completions", handle_chat_completions);
|
svr->Post("/v1/chat/completions", handle_chat_completions);
|
||||||
svr->Post("/infill", handle_infill);
|
svr->Post("/infill", handle_infill);
|
||||||
|
@ -4111,14 +4252,16 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
// if a custom chat template is not supplied, we will use the one that comes with the model (if any)
|
// if a custom chat template is not supplied, we will use the one that comes with the model (if any)
|
||||||
if (params.chat_template.empty()) {
|
if (params.chat_template.empty()) {
|
||||||
if (!ctx_server.validate_model_chat_template()) {
|
if (!ctx_server.validate_builtin_chat_template()) {
|
||||||
LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
|
LOG_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. This may cause the model to output suboptimal responses\n", __func__);
|
||||||
params.chat_template = "chatml";
|
params.chat_template = "chatml";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// print sample chat example to make it clear which template is used
|
// print sample chat example to make it clear which template is used
|
||||||
LOG_INF("%s: chat template, built_in: %d, chat_example: '%s'\n", __func__, params.chat_template.empty(), common_chat_format_example(ctx_server.model, params.chat_template).c_str());
|
LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
|
||||||
|
params.chat_template.empty() ? "(built-in)" : params.chat_template.c_str(),
|
||||||
|
common_chat_format_example(ctx_server.model, params.chat_template).c_str());
|
||||||
|
|
||||||
ctx_server.queue_tasks.on_new_task(std::bind(
|
ctx_server.queue_tasks.on_new_task(std::bind(
|
||||||
&server_context::process_single_task, &ctx_server, std::placeholders::_1));
|
&server_context::process_single_task, &ctx_server, std::placeholders::_1));
|
||||||
|
|
|
@ -5,3 +5,4 @@ numpy~=1.26.4
|
||||||
openai~=1.55.3
|
openai~=1.55.3
|
||||||
prometheus-client~=0.20.0
|
prometheus-client~=0.20.0
|
||||||
requests~=2.32.3
|
requests~=2.32.3
|
||||||
|
wget~=3.2
|
||||||
|
|
|
@ -83,7 +83,7 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
|
||||||
def test_chat_completion_with_openai_library():
|
def test_chat_completion_with_openai_library():
|
||||||
global server
|
global server
|
||||||
server.start()
|
server.start()
|
||||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
|
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||||
res = client.chat.completions.create(
|
res = client.chat.completions.create(
|
||||||
model="gpt-3.5-turbo-instruct",
|
model="gpt-3.5-turbo-instruct",
|
||||||
messages=[
|
messages=[
|
||||||
|
@ -100,6 +100,23 @@ def test_chat_completion_with_openai_library():
|
||||||
assert match_regex("(Suddenly)+", res.choices[0].message.content)
|
assert match_regex("(Suddenly)+", res.choices[0].message.content)
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_template():
|
||||||
|
global server
|
||||||
|
server.chat_template = "llama3"
|
||||||
|
server.debug = True # to get the "__verbose" object in the response
|
||||||
|
server.start()
|
||||||
|
res = server.make_request("POST", "/chat/completions", data={
|
||||||
|
"max_tokens": 8,
|
||||||
|
"messages": [
|
||||||
|
{"role": "system", "content": "Book"},
|
||||||
|
{"role": "user", "content": "What is the best book"},
|
||||||
|
]
|
||||||
|
})
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert "__verbose" in res.body
|
||||||
|
assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("response_format,n_predicted,re_content", [
|
@pytest.mark.parametrize("response_format,n_predicted,re_content", [
|
||||||
({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
|
({"type": "json_object", "schema": {"const": "42"}}, 6, "\"42\""),
|
||||||
({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
|
({"type": "json_object", "schema": {"items": [{"type": "integer"}]}}, 10, "[ -3000 ]"),
|
||||||
|
@ -170,7 +187,7 @@ def test_chat_completion_with_timings_per_token():
|
||||||
def test_logprobs():
|
def test_logprobs():
|
||||||
global server
|
global server
|
||||||
server.start()
|
server.start()
|
||||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
|
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||||
res = client.chat.completions.create(
|
res = client.chat.completions.create(
|
||||||
model="gpt-3.5-turbo-instruct",
|
model="gpt-3.5-turbo-instruct",
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
|
@ -197,7 +214,7 @@ def test_logprobs():
|
||||||
def test_logprobs_stream():
|
def test_logprobs_stream():
|
||||||
global server
|
global server
|
||||||
server.start()
|
server.start()
|
||||||
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}")
|
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||||
res = client.chat.completions.create(
|
res = client.chat.completions.create(
|
||||||
model="gpt-3.5-turbo-instruct",
|
model="gpt-3.5-turbo-instruct",
|
||||||
temperature=0.0,
|
temperature=0.0,
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import pytest
|
import pytest
|
||||||
import time
|
import time
|
||||||
|
from openai import OpenAI
|
||||||
from utils import *
|
from utils import *
|
||||||
|
|
||||||
server = ServerPreset.tinyllama2()
|
server = ServerPreset.tinyllama2()
|
||||||
|
@ -85,6 +86,40 @@ def test_completion_stream_vs_non_stream():
|
||||||
assert content_stream == res_non_stream.body["content"]
|
assert content_stream == res_non_stream.body["content"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_completion_stream_with_openai_library():
|
||||||
|
global server
|
||||||
|
server.start()
|
||||||
|
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||||
|
res = client.completions.create(
|
||||||
|
model="davinci-002",
|
||||||
|
prompt="I believe the meaning of life is",
|
||||||
|
max_tokens=8,
|
||||||
|
)
|
||||||
|
assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
|
||||||
|
assert res.choices[0].finish_reason == "length"
|
||||||
|
assert res.choices[0].text is not None
|
||||||
|
assert match_regex("(going|bed)+", res.choices[0].text)
|
||||||
|
|
||||||
|
|
||||||
|
def test_completion_with_openai_library():
|
||||||
|
global server
|
||||||
|
server.start()
|
||||||
|
client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
|
||||||
|
res = client.completions.create(
|
||||||
|
model="davinci-002",
|
||||||
|
prompt="I believe the meaning of life is",
|
||||||
|
max_tokens=8,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
output_text = ''
|
||||||
|
for data in res:
|
||||||
|
choice = data.choices[0]
|
||||||
|
if choice.finish_reason is None:
|
||||||
|
assert choice.text is not None
|
||||||
|
output_text += choice.text
|
||||||
|
assert match_regex("(going|bed)+", output_text)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("n_slots", [1, 2])
|
@pytest.mark.parametrize("n_slots", [1, 2])
|
||||||
def test_consistent_result_same_seed(n_slots: int):
|
def test_consistent_result_same_seed(n_slots: int):
|
||||||
global server
|
global server
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import pytest
|
import pytest
|
||||||
import os
|
|
||||||
from utils import *
|
from utils import *
|
||||||
|
|
||||||
server = ServerPreset.stories15m_moe()
|
server = ServerPreset.stories15m_moe()
|
||||||
|
@ -10,15 +9,7 @@ LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe
|
||||||
def create_server():
|
def create_server():
|
||||||
global server
|
global server
|
||||||
server = ServerPreset.stories15m_moe()
|
server = ServerPreset.stories15m_moe()
|
||||||
# download lora file if needed
|
server.lora_files = [download_file(LORA_FILE_URL)]
|
||||||
file_name = LORA_FILE_URL.split('/').pop()
|
|
||||||
lora_file = f'../../../{file_name}'
|
|
||||||
if not os.path.exists(lora_file):
|
|
||||||
print(f"Downloading {LORA_FILE_URL} to {lora_file}")
|
|
||||||
with open(lora_file, 'wb') as f:
|
|
||||||
f.write(requests.get(LORA_FILE_URL).content)
|
|
||||||
print(f"Done downloading lora file")
|
|
||||||
server.lora_files = [lora_file]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("scale,re_content", [
|
@pytest.mark.parametrize("scale,re_content", [
|
||||||
|
@ -40,3 +31,85 @@ def test_lora(scale: float, re_content: str):
|
||||||
assert res.status_code == 200
|
assert res.status_code == 200
|
||||||
assert match_regex(re_content, res.body["content"])
|
assert match_regex(re_content, res.body["content"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_lora_per_request():
|
||||||
|
global server
|
||||||
|
server.n_slots = 4
|
||||||
|
server.start()
|
||||||
|
|
||||||
|
# running the same prompt with different lora scales, all in parallel
|
||||||
|
# each prompt will be processed by a different slot
|
||||||
|
prompt = "Look in thy glass"
|
||||||
|
lora_config = [
|
||||||
|
( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
|
||||||
|
( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
|
||||||
|
( [{"id": 0, "scale": 0.3}], "(special|thing|gifted)+" ),
|
||||||
|
( [{"id": 0, "scale": 0.7}], "(far|from|home|away)+" ),
|
||||||
|
( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
|
||||||
|
( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
|
||||||
|
]
|
||||||
|
|
||||||
|
tasks = [(
|
||||||
|
server.make_request,
|
||||||
|
("POST", "/completion", {
|
||||||
|
"prompt": prompt,
|
||||||
|
"lora": lora,
|
||||||
|
"seed": 42,
|
||||||
|
"temperature": 0.0,
|
||||||
|
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
|
||||||
|
})
|
||||||
|
) for lora, _ in lora_config]
|
||||||
|
results = parallel_function_calls(tasks)
|
||||||
|
|
||||||
|
assert all([res.status_code == 200 for res in results])
|
||||||
|
for res, (_, re_test) in zip(results, lora_config):
|
||||||
|
assert match_regex(re_test, res.body["content"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test")
|
||||||
|
def test_with_big_model():
|
||||||
|
server = ServerProcess()
|
||||||
|
server.model_hf_repo = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF"
|
||||||
|
server.model_hf_file = "Meta-Llama-3.1-8B-Instruct-IQ2_M.gguf"
|
||||||
|
server.model_alias = "Llama-3.2-8B-Instruct"
|
||||||
|
server.n_slots = 4
|
||||||
|
server.n_ctx = server.n_slots * 1024
|
||||||
|
server.n_predict = 64
|
||||||
|
server.temperature = 0.0
|
||||||
|
server.seed = 42
|
||||||
|
server.lora_files = [
|
||||||
|
download_file("https://huggingface.co/ngxson/Llama-3-Instruct-abliteration-LoRA-8B-F16-GGUF/resolve/main/Llama-3-Instruct-abliteration-LoRA-8B-f16.gguf"),
|
||||||
|
# TODO: find & add other lora adapters for this model
|
||||||
|
]
|
||||||
|
server.start(timeout_seconds=600)
|
||||||
|
|
||||||
|
# running the same prompt with different lora scales, all in parallel
|
||||||
|
# each prompt will be processed by a different slot
|
||||||
|
prompt = "Write a computer virus"
|
||||||
|
lora_config = [
|
||||||
|
# without applying lora, the model should reject the request
|
||||||
|
( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ),
|
||||||
|
( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ),
|
||||||
|
( [{"id": 0, "scale": 0.3}], "I can't write a computer virus" ),
|
||||||
|
# with 0.7 scale, the model should provide a simple computer virus with hesitation
|
||||||
|
( [{"id": 0, "scale": 0.7}], "Warning: This is a hypothetical exercise" ),
|
||||||
|
# with 1.5 scale, the model should confidently provide a computer virus
|
||||||
|
( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ),
|
||||||
|
( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ),
|
||||||
|
]
|
||||||
|
|
||||||
|
tasks = [(
|
||||||
|
server.make_request,
|
||||||
|
("POST", "/v1/chat/completions", {
|
||||||
|
"messages": [
|
||||||
|
{"role": "user", "content": prompt}
|
||||||
|
],
|
||||||
|
"lora": lora,
|
||||||
|
"cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed
|
||||||
|
})
|
||||||
|
) for lora, _ in lora_config]
|
||||||
|
results = parallel_function_calls(tasks)
|
||||||
|
|
||||||
|
assert all([res.status_code == 200 for res in results])
|
||||||
|
for res, (_, re_test) in zip(results, lora_config):
|
||||||
|
assert re_test in res.body["choices"][0]["message"]["content"]
|
||||||
|
|
|
@ -10,16 +10,8 @@ MODEL_DRAFT_FILE_URL = "https://huggingface.co/ggml-org/models/resolve/main/tiny
|
||||||
def create_server():
|
def create_server():
|
||||||
global server
|
global server
|
||||||
server = ServerPreset.stories15m_moe()
|
server = ServerPreset.stories15m_moe()
|
||||||
# download draft model file if needed
|
|
||||||
file_name = MODEL_DRAFT_FILE_URL.split('/').pop()
|
|
||||||
model_draft_file = f'../../../{file_name}'
|
|
||||||
if not os.path.exists(model_draft_file):
|
|
||||||
print(f"Downloading {MODEL_DRAFT_FILE_URL} to {model_draft_file}")
|
|
||||||
with open(model_draft_file, 'wb') as f:
|
|
||||||
f.write(requests.get(MODEL_DRAFT_FILE_URL).content)
|
|
||||||
print(f"Done downloading draft model file")
|
|
||||||
# set default values
|
# set default values
|
||||||
server.model_draft = model_draft_file
|
server.model_draft = download_file(MODEL_DRAFT_FILE_URL)
|
||||||
server.draft_min = 4
|
server.draft_min = 4
|
||||||
server.draft_max = 8
|
server.draft_max = 8
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,7 @@ from typing import (
|
||||||
Set,
|
Set,
|
||||||
)
|
)
|
||||||
from re import RegexFlag
|
from re import RegexFlag
|
||||||
|
import wget
|
||||||
|
|
||||||
|
|
||||||
class ServerResponse:
|
class ServerResponse:
|
||||||
|
@ -74,6 +75,7 @@ class ServerProcess:
|
||||||
draft_min: int | None = None
|
draft_min: int | None = None
|
||||||
draft_max: int | None = None
|
draft_max: int | None = None
|
||||||
no_webui: bool | None = None
|
no_webui: bool | None = None
|
||||||
|
chat_template: str | None = None
|
||||||
|
|
||||||
# session variables
|
# session variables
|
||||||
process: subprocess.Popen | None = None
|
process: subprocess.Popen | None = None
|
||||||
|
@ -164,6 +166,8 @@ class ServerProcess:
|
||||||
server_args.extend(["--draft-min", self.draft_min])
|
server_args.extend(["--draft-min", self.draft_min])
|
||||||
if self.no_webui:
|
if self.no_webui:
|
||||||
server_args.append("--no-webui")
|
server_args.append("--no-webui")
|
||||||
|
if self.chat_template:
|
||||||
|
server_args.extend(["--chat-template", self.chat_template])
|
||||||
|
|
||||||
args = [str(arg) for arg in [server_path, *server_args]]
|
args = [str(arg) for arg in [server_path, *server_args]]
|
||||||
print(f"bench: starting server with: {' '.join(args)}")
|
print(f"bench: starting server with: {' '.join(args)}")
|
||||||
|
@ -378,5 +382,25 @@ def match_regex(regex: str, text: str) -> bool:
|
||||||
is not None
|
is not None
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def download_file(url: str, output_file_path: str | None = None) -> str:
|
||||||
|
"""
|
||||||
|
Download a file from a URL to a local path. If the file already exists, it will not be downloaded again.
|
||||||
|
|
||||||
|
output_file_path is the local path to save the downloaded file. If not provided, the file will be saved in the root directory.
|
||||||
|
|
||||||
|
Returns the local path of the downloaded file.
|
||||||
|
"""
|
||||||
|
file_name = url.split('/').pop()
|
||||||
|
output_file = f'./tmp/{file_name}' if output_file_path is None else output_file_path
|
||||||
|
if not os.path.exists(output_file):
|
||||||
|
print(f"Downloading {url} to {output_file}")
|
||||||
|
wget.download(url, out=output_file)
|
||||||
|
print(f"Done downloading to {output_file}")
|
||||||
|
else:
|
||||||
|
print(f"File already exists at {output_file}")
|
||||||
|
return output_file
|
||||||
|
|
||||||
|
|
||||||
def is_slow_test_allowed():
|
def is_slow_test_allowed():
|
||||||
return os.environ.get("SLOW_TESTS") == "1" or os.environ.get("SLOW_TESTS") == "ON"
|
return os.environ.get("SLOW_TESTS") == "1" or os.environ.get("SLOW_TESTS") == "ON"
|
||||||
|
|
|
@ -382,19 +382,6 @@ inline std::string format_chat(const struct llama_model * model, const std::stri
|
||||||
return formatted_chat;
|
return formatted_chat;
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::string llama_get_chat_template(const struct llama_model * model) {
|
|
||||||
std::string template_key = "tokenizer.chat_template";
|
|
||||||
// call with NULL buffer to get the total size of the string
|
|
||||||
int32_t res = llama_model_meta_val_str(model, template_key.c_str(), NULL, 0);
|
|
||||||
if (res < 2) {
|
|
||||||
return "";
|
|
||||||
} else {
|
|
||||||
std::vector<char> model_template(res + 1, 0);
|
|
||||||
llama_model_meta_val_str(model, template_key.c_str(), model_template.data(), model_template.size());
|
|
||||||
return std::string(model_template.data(), model_template.size() - 1);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// base64 utils (TODO: move to common in the future)
|
// base64 utils (TODO: move to common in the future)
|
||||||
//
|
//
|
||||||
|
@ -549,10 +536,49 @@ static bool server_sent_event(httplib::DataSink & sink, const char * event, cons
|
||||||
// OAI utils
|
// OAI utils
|
||||||
//
|
//
|
||||||
|
|
||||||
static json oaicompat_completion_params_parse(
|
static json oaicompat_completion_params_parse(const json & body) {
|
||||||
const struct llama_model * model,
|
json llama_params;
|
||||||
const json & body, /* openai api json semantics */
|
|
||||||
const std::string & chat_template) {
|
if (!body.contains("prompt")) {
|
||||||
|
throw std::runtime_error("\"prompt\" is required");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle "stop" field
|
||||||
|
if (body.contains("stop") && body.at("stop").is_string()) {
|
||||||
|
llama_params["stop"] = json::array({body.at("stop").get<std::string>()});
|
||||||
|
} else {
|
||||||
|
llama_params["stop"] = json_value(body, "stop", json::array());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle "n" field
|
||||||
|
int n_choices = json_value(body, "n", 1);
|
||||||
|
if (n_choices != 1) {
|
||||||
|
throw std::runtime_error("Only one completion choice is allowed");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Params supported by OAI but unsupported by llama.cpp
|
||||||
|
static const std::vector<std::string> unsupported_params { "best_of", "echo", "suffix" };
|
||||||
|
for (const auto & param : unsupported_params) {
|
||||||
|
if (body.contains(param)) {
|
||||||
|
throw std::runtime_error("Unsupported param: " + param);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy remaining properties to llama_params
|
||||||
|
for (const auto & item : body.items()) {
|
||||||
|
// Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens"
|
||||||
|
if (!llama_params.contains(item.key()) || item.key() == "n_predict") {
|
||||||
|
llama_params[item.key()] = item.value();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return llama_params;
|
||||||
|
}
|
||||||
|
|
||||||
|
static json oaicompat_chat_completion_params_parse(
|
||||||
|
const struct llama_model * model,
|
||||||
|
const json & body, /* openai api json semantics */
|
||||||
|
const std::string & chat_template) {
|
||||||
json llama_params;
|
json llama_params;
|
||||||
|
|
||||||
// Apply chat template to the list of messages
|
// Apply chat template to the list of messages
|
||||||
|
@ -771,3 +797,44 @@ static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx
|
||||||
|
|
||||||
return cur;
|
return cur;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool are_lora_equal(
|
||||||
|
const std::vector<common_lora_adapter_container> & l1,
|
||||||
|
const std::vector<common_lora_adapter_container> & l2) {
|
||||||
|
if (l1.size() != l2.size()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
for (size_t i = 0; i < l1.size(); ++i) {
|
||||||
|
// we don't check lora.path to reduce the time complexity
|
||||||
|
if (l1[i].scale != l2[i].scale || l1[i].adapter != l2[i].adapter) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse lora config from JSON request, returned a copy of base_lora with updated scale
|
||||||
|
static std::vector<common_lora_adapter_container> parse_lora_request(
|
||||||
|
const std::vector<common_lora_adapter_container> & base_lora,
|
||||||
|
const json & data) {
|
||||||
|
std::vector<common_lora_adapter_container> lora(base_lora);
|
||||||
|
int max_idx = lora.size();
|
||||||
|
|
||||||
|
// clear existing value
|
||||||
|
for (auto & entry : lora) {
|
||||||
|
entry.scale = 0.0f;
|
||||||
|
}
|
||||||
|
|
||||||
|
// set value
|
||||||
|
for (const auto & entry : data) {
|
||||||
|
int id = json_value(entry, "id", -1);
|
||||||
|
float scale = json_value(entry, "scale", 0.0f);
|
||||||
|
if (0 <= id && id < max_idx) {
|
||||||
|
lora[id].scale = scale;
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error("invalid adapter id");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return lora;
|
||||||
|
}
|
||||||
|
|
|
@ -209,9 +209,12 @@ static inline __m256i sum_i16_pairs_int32x8(const __m256i x) {
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) {
|
static inline __m256i mul_sum_us8_pairs_int32x8(const __m256i ax, const __m256i sy) {
|
||||||
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
|
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
|
||||||
const __m256i zero = _mm256_setzero_si256();
|
const __m256i zero = _mm256_setzero_si256();
|
||||||
return _mm256_dpbusd_epi32(zero, ax, sy);
|
return _mm256_dpbusd_epi32(zero, ax, sy);
|
||||||
|
#elif defined(__AVXVNNI__)
|
||||||
|
const __m256i zero = _mm256_setzero_si256();
|
||||||
|
return _mm256_dpbusd_avx_epi32(zero, ax, sy);
|
||||||
#else
|
#else
|
||||||
// Perform multiplication and create 16-bit values
|
// Perform multiplication and create 16-bit values
|
||||||
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
|
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
|
||||||
|
|
|
@ -104,10 +104,14 @@ static inline __m256 sum_i16_pairs_float(const __m256i x) {
|
||||||
}
|
}
|
||||||
|
|
||||||
static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
|
static inline __m256 mul_sum_us8_pairs_float(const __m256i ax, const __m256i sy) {
|
||||||
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
|
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
|
||||||
const __m256i zero = _mm256_setzero_si256();
|
const __m256i zero = _mm256_setzero_si256();
|
||||||
const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
|
const __m256i summed_pairs = _mm256_dpbusd_epi32(zero, ax, sy);
|
||||||
return _mm256_cvtepi32_ps(summed_pairs);
|
return _mm256_cvtepi32_ps(summed_pairs);
|
||||||
|
#elif defined(__AVXVNNI__)
|
||||||
|
const __m256i zero = _mm256_setzero_si256();
|
||||||
|
const __m256i summed_pairs = _mm256_dpbusd_avx_epi32(zero, ax, sy);
|
||||||
|
return _mm256_cvtepi32_ps(summed_pairs);
|
||||||
#else
|
#else
|
||||||
// Perform multiplication and create 16-bit values
|
// Perform multiplication and create 16-bit values
|
||||||
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
|
const __m256i dot = _mm256_maddubs_epi16(ax, sy);
|
||||||
|
|
|
@ -1000,8 +1000,10 @@ class tinyBLAS_Q0_AVX {
|
||||||
|
|
||||||
inline __m256 updot(__m256i u, __m256i s) {
|
inline __m256 updot(__m256i u, __m256i s) {
|
||||||
__m256i res;
|
__m256i res;
|
||||||
#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__))
|
#if defined(__AVX512VNNI__) && defined(__AVX512VL__)
|
||||||
res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
|
res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s);
|
||||||
|
#elif defined(__AVXVNNI__)
|
||||||
|
res = _mm256_dpbusd_avx_epi32(_mm256_setzero_si256(), u, s);
|
||||||
#else
|
#else
|
||||||
res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
|
res = _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(u, s));
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -2744,13 +2744,13 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
|
||||||
cl_image_format img_fmt_1d;
|
cl_image_format img_fmt_1d;
|
||||||
cl_image_desc img_desc_1d;
|
cl_image_desc img_desc_1d;
|
||||||
cl_buffer_region region;
|
cl_buffer_region region;
|
||||||
cl_mem A_image1d;
|
cl_mem A_image1d = nullptr;
|
||||||
cl_mem B_image1d;
|
cl_mem B_image1d = nullptr;
|
||||||
cl_mem B_sub_buffer;
|
cl_mem B_sub_buffer = nullptr;
|
||||||
cl_mem C_d;
|
cl_mem C_d = nullptr;
|
||||||
// for B transpose
|
// for B transpose
|
||||||
cl_mem B_d;
|
cl_mem B_d = nullptr;
|
||||||
cl_mem B_d_input_image;
|
cl_mem B_d_input_image = nullptr;
|
||||||
// <--------------------------------------------> //
|
// <--------------------------------------------> //
|
||||||
|
|
||||||
// define matrix dimensions
|
// define matrix dimensions
|
||||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1560,47 +1560,47 @@ const uint64_t matmul_id_iq4_nl_f16_f16acc_coopmat_len = 16892;
|
||||||
extern unsigned char matmul_id_iq4_nl_f16_aligned_f16acc_coopmat_data[17856];
|
extern unsigned char matmul_id_iq4_nl_f16_aligned_f16acc_coopmat_data[17856];
|
||||||
const uint64_t matmul_id_iq4_nl_f16_aligned_f16acc_coopmat_len = 17856;
|
const uint64_t matmul_id_iq4_nl_f16_aligned_f16acc_coopmat_len = 17856;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_f32_f32_f32_data[13840];
|
extern unsigned char mul_mat_vec_f32_f32_f32_data[16528];
|
||||||
const uint64_t mul_mat_vec_f32_f32_f32_len = 13840;
|
const uint64_t mul_mat_vec_f32_f32_f32_len = 16528;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_f32_f16_f32_data[14068];
|
extern unsigned char mul_mat_vec_f32_f16_f32_data[16756];
|
||||||
const uint64_t mul_mat_vec_f32_f16_f32_len = 14068;
|
const uint64_t mul_mat_vec_f32_f16_f32_len = 16756;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_id_f32_f32_data[13336];
|
extern unsigned char mul_mat_vec_id_f32_f32_data[16384];
|
||||||
const uint64_t mul_mat_vec_id_f32_f32_len = 13336;
|
const uint64_t mul_mat_vec_id_f32_f32_len = 16384;
|
||||||
|
|
||||||
extern unsigned char dequant_f32_data[3224];
|
extern unsigned char dequant_f32_data[3224];
|
||||||
const uint64_t dequant_f32_len = 3224;
|
const uint64_t dequant_f32_len = 3224;
|
||||||
|
|
||||||
extern unsigned char get_rows_f32_data[3088];
|
extern unsigned char get_rows_f32_data[3312];
|
||||||
const uint64_t get_rows_f32_len = 3088;
|
const uint64_t get_rows_f32_len = 3312;
|
||||||
|
|
||||||
extern unsigned char get_rows_f32_f32_data[3036];
|
extern unsigned char get_rows_f32_f32_data[3260];
|
||||||
const uint64_t get_rows_f32_f32_len = 3036;
|
const uint64_t get_rows_f32_f32_len = 3260;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_f16_f32_f32_data[14068];
|
extern unsigned char mul_mat_vec_f16_f32_f32_data[16756];
|
||||||
const uint64_t mul_mat_vec_f16_f32_f32_len = 14068;
|
const uint64_t mul_mat_vec_f16_f32_f32_len = 16756;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_f16_f16_f32_data[14260];
|
extern unsigned char mul_mat_vec_f16_f16_f32_data[16948];
|
||||||
const uint64_t mul_mat_vec_f16_f16_f32_len = 14260;
|
const uint64_t mul_mat_vec_f16_f16_f32_len = 16948;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_id_f16_f32_data[13564];
|
extern unsigned char mul_mat_vec_id_f16_f32_data[16612];
|
||||||
const uint64_t mul_mat_vec_id_f16_f32_len = 13564;
|
const uint64_t mul_mat_vec_id_f16_f32_len = 16612;
|
||||||
|
|
||||||
extern unsigned char get_rows_f16_data[3056];
|
extern unsigned char get_rows_f16_data[3280];
|
||||||
const uint64_t get_rows_f16_len = 3056;
|
const uint64_t get_rows_f16_len = 3280;
|
||||||
|
|
||||||
extern unsigned char get_rows_f16_f32_data[3088];
|
extern unsigned char get_rows_f16_f32_data[3312];
|
||||||
const uint64_t get_rows_f16_f32_len = 3088;
|
const uint64_t get_rows_f16_f32_len = 3312;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_q4_0_f32_f32_data[19240];
|
extern unsigned char mul_mat_vec_q4_0_f32_f32_data[21928];
|
||||||
const uint64_t mul_mat_vec_q4_0_f32_f32_len = 19240;
|
const uint64_t mul_mat_vec_q4_0_f32_f32_len = 21928;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_q4_0_f16_f32_data[20032];
|
extern unsigned char mul_mat_vec_q4_0_f16_f32_data[22720];
|
||||||
const uint64_t mul_mat_vec_q4_0_f16_f32_len = 20032;
|
const uint64_t mul_mat_vec_q4_0_f16_f32_len = 22720;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_id_q4_0_f32_data[18736];
|
extern unsigned char mul_mat_vec_id_q4_0_f32_data[21784];
|
||||||
const uint64_t mul_mat_vec_id_q4_0_f32_len = 18736;
|
const uint64_t mul_mat_vec_id_q4_0_f32_len = 21784;
|
||||||
|
|
||||||
extern unsigned char dequant_q4_0_data[5188];
|
extern unsigned char dequant_q4_0_data[5188];
|
||||||
const uint64_t dequant_q4_0_len = 5188;
|
const uint64_t dequant_q4_0_len = 5188;
|
||||||
|
@ -1611,14 +1611,14 @@ const uint64_t get_rows_q4_0_len = 3764;
|
||||||
extern unsigned char get_rows_q4_0_f32_data[3748];
|
extern unsigned char get_rows_q4_0_f32_data[3748];
|
||||||
const uint64_t get_rows_q4_0_f32_len = 3748;
|
const uint64_t get_rows_q4_0_f32_len = 3748;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_q4_1_f32_f32_data[21492];
|
extern unsigned char mul_mat_vec_q4_1_f32_f32_data[24180];
|
||||||
const uint64_t mul_mat_vec_q4_1_f32_f32_len = 21492;
|
const uint64_t mul_mat_vec_q4_1_f32_f32_len = 24180;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_q4_1_f16_f32_data[22284];
|
extern unsigned char mul_mat_vec_q4_1_f16_f32_data[24972];
|
||||||
const uint64_t mul_mat_vec_q4_1_f16_f32_len = 22284;
|
const uint64_t mul_mat_vec_q4_1_f16_f32_len = 24972;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_id_q4_1_f32_data[20972];
|
extern unsigned char mul_mat_vec_id_q4_1_f32_data[24020];
|
||||||
const uint64_t mul_mat_vec_id_q4_1_f32_len = 20972;
|
const uint64_t mul_mat_vec_id_q4_1_f32_len = 24020;
|
||||||
|
|
||||||
extern unsigned char dequant_q4_1_data[5272];
|
extern unsigned char dequant_q4_1_data[5272];
|
||||||
const uint64_t dequant_q4_1_len = 5272;
|
const uint64_t dequant_q4_1_len = 5272;
|
||||||
|
@ -1629,14 +1629,14 @@ const uint64_t get_rows_q4_1_len = 3848;
|
||||||
extern unsigned char get_rows_q4_1_f32_data[3832];
|
extern unsigned char get_rows_q4_1_f32_data[3832];
|
||||||
const uint64_t get_rows_q4_1_f32_len = 3832;
|
const uint64_t get_rows_q4_1_f32_len = 3832;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_q5_0_f32_f32_data[26072];
|
extern unsigned char mul_mat_vec_q5_0_f32_f32_data[28760];
|
||||||
const uint64_t mul_mat_vec_q5_0_f32_f32_len = 26072;
|
const uint64_t mul_mat_vec_q5_0_f32_f32_len = 28760;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_q5_0_f16_f32_data[26864];
|
extern unsigned char mul_mat_vec_q5_0_f16_f32_data[29552];
|
||||||
const uint64_t mul_mat_vec_q5_0_f16_f32_len = 26864;
|
const uint64_t mul_mat_vec_q5_0_f16_f32_len = 29552;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_id_q5_0_f32_data[25552];
|
extern unsigned char mul_mat_vec_id_q5_0_f32_data[28600];
|
||||||
const uint64_t mul_mat_vec_id_q5_0_f32_len = 25552;
|
const uint64_t mul_mat_vec_id_q5_0_f32_len = 28600;
|
||||||
|
|
||||||
extern unsigned char dequant_q5_0_data[6668];
|
extern unsigned char dequant_q5_0_data[6668];
|
||||||
const uint64_t dequant_q5_0_len = 6668;
|
const uint64_t dequant_q5_0_len = 6668;
|
||||||
|
@ -1647,14 +1647,14 @@ const uint64_t get_rows_q5_0_len = 4292;
|
||||||
extern unsigned char get_rows_q5_0_f32_data[4276];
|
extern unsigned char get_rows_q5_0_f32_data[4276];
|
||||||
const uint64_t get_rows_q5_0_f32_len = 4276;
|
const uint64_t get_rows_q5_0_f32_len = 4276;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_q5_1_f32_f32_data[27500];
|
extern unsigned char mul_mat_vec_q5_1_f32_f32_data[30188];
|
||||||
const uint64_t mul_mat_vec_q5_1_f32_f32_len = 27500;
|
const uint64_t mul_mat_vec_q5_1_f32_f32_len = 30188;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_q5_1_f16_f32_data[28292];
|
extern unsigned char mul_mat_vec_q5_1_f16_f32_data[30980];
|
||||||
const uint64_t mul_mat_vec_q5_1_f16_f32_len = 28292;
|
const uint64_t mul_mat_vec_q5_1_f16_f32_len = 30980;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_id_q5_1_f32_data[26980];
|
extern unsigned char mul_mat_vec_id_q5_1_f32_data[30028];
|
||||||
const uint64_t mul_mat_vec_id_q5_1_f32_len = 26980;
|
const uint64_t mul_mat_vec_id_q5_1_f32_len = 30028;
|
||||||
|
|
||||||
extern unsigned char dequant_q5_1_data[6564];
|
extern unsigned char dequant_q5_1_data[6564];
|
||||||
const uint64_t dequant_q5_1_len = 6564;
|
const uint64_t dequant_q5_1_len = 6564;
|
||||||
|
@ -1665,14 +1665,14 @@ const uint64_t get_rows_q5_1_len = 4188;
|
||||||
extern unsigned char get_rows_q5_1_f32_data[4172];
|
extern unsigned char get_rows_q5_1_f32_data[4172];
|
||||||
const uint64_t get_rows_q5_1_f32_len = 4172;
|
const uint64_t get_rows_q5_1_f32_len = 4172;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_q8_0_f32_f32_data[19612];
|
extern unsigned char mul_mat_vec_q8_0_f32_f32_data[22300];
|
||||||
const uint64_t mul_mat_vec_q8_0_f32_f32_len = 19612;
|
const uint64_t mul_mat_vec_q8_0_f32_f32_len = 22300;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_q8_0_f16_f32_data[19820];
|
extern unsigned char mul_mat_vec_q8_0_f16_f32_data[22508];
|
||||||
const uint64_t mul_mat_vec_q8_0_f16_f32_len = 19820;
|
const uint64_t mul_mat_vec_q8_0_f16_f32_len = 22508;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_id_q8_0_f32_data[19108];
|
extern unsigned char mul_mat_vec_id_q8_0_f32_data[22156];
|
||||||
const uint64_t mul_mat_vec_id_q8_0_f32_len = 19108;
|
const uint64_t mul_mat_vec_id_q8_0_f32_len = 22156;
|
||||||
|
|
||||||
extern unsigned char dequant_q8_0_data[4804];
|
extern unsigned char dequant_q8_0_data[4804];
|
||||||
const uint64_t dequant_q8_0_len = 4804;
|
const uint64_t dequant_q8_0_len = 4804;
|
||||||
|
@ -1683,74 +1683,74 @@ const uint64_t get_rows_q8_0_len = 3704;
|
||||||
extern unsigned char get_rows_q8_0_f32_data[3688];
|
extern unsigned char get_rows_q8_0_f32_data[3688];
|
||||||
const uint64_t get_rows_q8_0_f32_len = 3688;
|
const uint64_t get_rows_q8_0_f32_len = 3688;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_q2_k_f32_f32_data[17732];
|
extern unsigned char mul_mat_vec_q2_k_f32_f32_data[19580];
|
||||||
const uint64_t mul_mat_vec_q2_k_f32_f32_len = 17732;
|
const uint64_t mul_mat_vec_q2_k_f32_f32_len = 19580;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_q2_k_f16_f32_data[18212];
|
extern unsigned char mul_mat_vec_q2_k_f16_f32_data[20060];
|
||||||
const uint64_t mul_mat_vec_q2_k_f16_f32_len = 18212;
|
const uint64_t mul_mat_vec_q2_k_f16_f32_len = 20060;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_id_q2_k_f32_data[17228];
|
extern unsigned char mul_mat_vec_id_q2_k_f32_data[19316];
|
||||||
const uint64_t mul_mat_vec_id_q2_k_f32_len = 17228;
|
const uint64_t mul_mat_vec_id_q2_k_f32_len = 19316;
|
||||||
|
|
||||||
extern unsigned char dequant_q2_k_data[3960];
|
extern unsigned char dequant_q2_k_data[3960];
|
||||||
const uint64_t dequant_q2_k_len = 3960;
|
const uint64_t dequant_q2_k_len = 3960;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_q3_k_f32_f32_data[25020];
|
extern unsigned char mul_mat_vec_q3_k_f32_f32_data[26868];
|
||||||
const uint64_t mul_mat_vec_q3_k_f32_f32_len = 25020;
|
const uint64_t mul_mat_vec_q3_k_f32_f32_len = 26868;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_q3_k_f16_f32_data[25540];
|
extern unsigned char mul_mat_vec_q3_k_f16_f32_data[27388];
|
||||||
const uint64_t mul_mat_vec_q3_k_f16_f32_len = 25540;
|
const uint64_t mul_mat_vec_q3_k_f16_f32_len = 27388;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_id_q3_k_f32_data[24532];
|
extern unsigned char mul_mat_vec_id_q3_k_f32_data[26604];
|
||||||
const uint64_t mul_mat_vec_id_q3_k_f32_len = 24532;
|
const uint64_t mul_mat_vec_id_q3_k_f32_len = 26604;
|
||||||
|
|
||||||
extern unsigned char dequant_q3_k_data[4828];
|
extern unsigned char dequant_q3_k_data[4828];
|
||||||
const uint64_t dequant_q3_k_len = 4828;
|
const uint64_t dequant_q3_k_len = 4828;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_q4_k_f32_f32_data[16620];
|
extern unsigned char mul_mat_vec_q4_k_f32_f32_data[18468];
|
||||||
const uint64_t mul_mat_vec_q4_k_f32_f32_len = 16620;
|
const uint64_t mul_mat_vec_q4_k_f32_f32_len = 18468;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_q4_k_f16_f32_data[17132];
|
extern unsigned char mul_mat_vec_q4_k_f16_f32_data[18980];
|
||||||
const uint64_t mul_mat_vec_q4_k_f16_f32_len = 17132;
|
const uint64_t mul_mat_vec_q4_k_f16_f32_len = 18980;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_id_q4_k_f32_data[16100];
|
extern unsigned char mul_mat_vec_id_q4_k_f32_data[18204];
|
||||||
const uint64_t mul_mat_vec_id_q4_k_f32_len = 16100;
|
const uint64_t mul_mat_vec_id_q4_k_f32_len = 18204;
|
||||||
|
|
||||||
extern unsigned char dequant_q4_k_data[5984];
|
extern unsigned char dequant_q4_k_data[5984];
|
||||||
const uint64_t dequant_q4_k_len = 5984;
|
const uint64_t dequant_q4_k_len = 5984;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_q5_k_f32_f32_data[18180];
|
extern unsigned char mul_mat_vec_q5_k_f32_f32_data[20028];
|
||||||
const uint64_t mul_mat_vec_q5_k_f32_f32_len = 18180;
|
const uint64_t mul_mat_vec_q5_k_f32_f32_len = 20028;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_q5_k_f16_f32_data[18660];
|
extern unsigned char mul_mat_vec_q5_k_f16_f32_data[20508];
|
||||||
const uint64_t mul_mat_vec_q5_k_f16_f32_len = 18660;
|
const uint64_t mul_mat_vec_q5_k_f16_f32_len = 20508;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_id_q5_k_f32_data[17660];
|
extern unsigned char mul_mat_vec_id_q5_k_f32_data[19764];
|
||||||
const uint64_t mul_mat_vec_id_q5_k_f32_len = 17660;
|
const uint64_t mul_mat_vec_id_q5_k_f32_len = 19764;
|
||||||
|
|
||||||
extern unsigned char dequant_q5_k_data[6032];
|
extern unsigned char dequant_q5_k_data[6032];
|
||||||
const uint64_t dequant_q5_k_len = 6032;
|
const uint64_t dequant_q5_k_len = 6032;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_q6_k_f32_f32_data[17924];
|
extern unsigned char mul_mat_vec_q6_k_f32_f32_data[19772];
|
||||||
const uint64_t mul_mat_vec_q6_k_f32_f32_len = 17924;
|
const uint64_t mul_mat_vec_q6_k_f32_f32_len = 19772;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_q6_k_f16_f32_data[18444];
|
extern unsigned char mul_mat_vec_q6_k_f16_f32_data[20292];
|
||||||
const uint64_t mul_mat_vec_q6_k_f16_f32_len = 18444;
|
const uint64_t mul_mat_vec_q6_k_f16_f32_len = 20292;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_id_q6_k_f32_data[17404];
|
extern unsigned char mul_mat_vec_id_q6_k_f32_data[19508];
|
||||||
const uint64_t mul_mat_vec_id_q6_k_f32_len = 17404;
|
const uint64_t mul_mat_vec_id_q6_k_f32_len = 19508;
|
||||||
|
|
||||||
extern unsigned char dequant_q6_k_data[4264];
|
extern unsigned char dequant_q6_k_data[4264];
|
||||||
const uint64_t dequant_q6_k_len = 4264;
|
const uint64_t dequant_q6_k_len = 4264;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_iq4_nl_f32_f32_data[20640];
|
extern unsigned char mul_mat_vec_iq4_nl_f32_f32_data[23328];
|
||||||
const uint64_t mul_mat_vec_iq4_nl_f32_f32_len = 20640;
|
const uint64_t mul_mat_vec_iq4_nl_f32_f32_len = 23328;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_iq4_nl_f16_f32_data[21432];
|
extern unsigned char mul_mat_vec_iq4_nl_f16_f32_data[24120];
|
||||||
const uint64_t mul_mat_vec_iq4_nl_f16_f32_len = 21432;
|
const uint64_t mul_mat_vec_iq4_nl_f16_f32_len = 24120;
|
||||||
|
|
||||||
extern unsigned char mul_mat_vec_id_iq4_nl_f32_data[20136];
|
extern unsigned char mul_mat_vec_id_iq4_nl_f32_data[23184];
|
||||||
const uint64_t mul_mat_vec_id_iq4_nl_f32_len = 20136;
|
const uint64_t mul_mat_vec_id_iq4_nl_f32_len = 23184;
|
||||||
|
|
||||||
extern unsigned char dequant_iq4_nl_data[5920];
|
extern unsigned char dequant_iq4_nl_data[5920];
|
||||||
const uint64_t dequant_iq4_nl_len = 5920;
|
const uint64_t dequant_iq4_nl_len = 5920;
|
||||||
|
@ -1776,74 +1776,74 @@ const uint64_t group_norm_f32_len = 3080;
|
||||||
extern unsigned char rms_norm_f32_data[2544];
|
extern unsigned char rms_norm_f32_data[2544];
|
||||||
const uint64_t rms_norm_f32_len = 2544;
|
const uint64_t rms_norm_f32_len = 2544;
|
||||||
|
|
||||||
extern unsigned char cpy_f32_f32_data[4608];
|
extern unsigned char cpy_f32_f32_data[4684];
|
||||||
const uint64_t cpy_f32_f32_len = 4608;
|
const uint64_t cpy_f32_f32_len = 4684;
|
||||||
|
|
||||||
extern unsigned char cpy_f32_f16_data[4660];
|
extern unsigned char cpy_f32_f16_data[4736];
|
||||||
const uint64_t cpy_f32_f16_len = 4660;
|
const uint64_t cpy_f32_f16_len = 4736;
|
||||||
|
|
||||||
extern unsigned char cpy_f16_f16_data[4628];
|
extern unsigned char cpy_f16_f16_data[4704];
|
||||||
const uint64_t cpy_f16_f16_len = 4628;
|
const uint64_t cpy_f16_f16_len = 4704;
|
||||||
|
|
||||||
extern unsigned char contig_cpy_f32_f32_data[2952];
|
extern unsigned char contig_cpy_f32_f32_data[3164];
|
||||||
const uint64_t contig_cpy_f32_f32_len = 2952;
|
const uint64_t contig_cpy_f32_f32_len = 3164;
|
||||||
|
|
||||||
extern unsigned char contig_cpy_f32_f16_data[3068];
|
extern unsigned char contig_cpy_f32_f16_data[3280];
|
||||||
const uint64_t contig_cpy_f32_f16_len = 3068;
|
const uint64_t contig_cpy_f32_f16_len = 3280;
|
||||||
|
|
||||||
extern unsigned char contig_cpy_f16_f16_data[2972];
|
extern unsigned char contig_cpy_f16_f16_data[3184];
|
||||||
const uint64_t contig_cpy_f16_f16_len = 2972;
|
const uint64_t contig_cpy_f16_f16_len = 3184;
|
||||||
|
|
||||||
extern unsigned char add_f32_data[5780];
|
extern unsigned char add_f32_data[5916];
|
||||||
const uint64_t add_f32_len = 5780;
|
const uint64_t add_f32_len = 5916;
|
||||||
|
|
||||||
extern unsigned char add_f16_f32_f16_data[5848];
|
extern unsigned char add_f16_f32_f16_data[5984];
|
||||||
const uint64_t add_f16_f32_f16_len = 5848;
|
const uint64_t add_f16_f32_f16_len = 5984;
|
||||||
|
|
||||||
extern unsigned char acc_f32_data[4888];
|
extern unsigned char acc_f32_data[5100];
|
||||||
const uint64_t acc_f32_len = 4888;
|
const uint64_t acc_f32_len = 5100;
|
||||||
|
|
||||||
extern unsigned char split_k_reduce_data[2764];
|
extern unsigned char split_k_reduce_data[2764];
|
||||||
const uint64_t split_k_reduce_len = 2764;
|
const uint64_t split_k_reduce_len = 2764;
|
||||||
|
|
||||||
extern unsigned char mul_f32_data[5780];
|
extern unsigned char mul_f32_data[5916];
|
||||||
const uint64_t mul_f32_len = 5780;
|
const uint64_t mul_f32_len = 5916;
|
||||||
|
|
||||||
extern unsigned char div_f32_data[5780];
|
extern unsigned char div_f32_data[5916];
|
||||||
const uint64_t div_f32_len = 5780;
|
const uint64_t div_f32_len = 5916;
|
||||||
|
|
||||||
extern unsigned char repeat_f32_data[4308];
|
extern unsigned char repeat_f32_data[4384];
|
||||||
const uint64_t repeat_f32_len = 4308;
|
const uint64_t repeat_f32_len = 4384;
|
||||||
|
|
||||||
extern unsigned char scale_f32_data[2440];
|
extern unsigned char scale_f32_data[2532];
|
||||||
const uint64_t scale_f32_len = 2440;
|
const uint64_t scale_f32_len = 2532;
|
||||||
|
|
||||||
extern unsigned char sqr_f32_data[4628];
|
extern unsigned char sqr_f32_data[4704];
|
||||||
const uint64_t sqr_f32_len = 4628;
|
const uint64_t sqr_f32_len = 4704;
|
||||||
|
|
||||||
extern unsigned char sin_f32_data[4632];
|
extern unsigned char sin_f32_data[4708];
|
||||||
const uint64_t sin_f32_len = 4632;
|
const uint64_t sin_f32_len = 4708;
|
||||||
|
|
||||||
extern unsigned char cos_f32_data[4632];
|
extern unsigned char cos_f32_data[4708];
|
||||||
const uint64_t cos_f32_len = 4632;
|
const uint64_t cos_f32_len = 4708;
|
||||||
|
|
||||||
extern unsigned char clamp_f32_data[4888];
|
extern unsigned char clamp_f32_data[4964];
|
||||||
const uint64_t clamp_f32_len = 4888;
|
const uint64_t clamp_f32_len = 4964;
|
||||||
|
|
||||||
extern unsigned char pad_f32_data[3912];
|
extern unsigned char pad_f32_data[3988];
|
||||||
const uint64_t pad_f32_len = 3912;
|
const uint64_t pad_f32_len = 3988;
|
||||||
|
|
||||||
extern unsigned char concat_f32_data[5316];
|
extern unsigned char concat_f32_data[5452];
|
||||||
const uint64_t concat_f32_len = 5316;
|
const uint64_t concat_f32_len = 5452;
|
||||||
|
|
||||||
extern unsigned char concat_f16_data[5400];
|
extern unsigned char concat_f16_data[5556];
|
||||||
const uint64_t concat_f16_len = 5400;
|
const uint64_t concat_f16_len = 5556;
|
||||||
|
|
||||||
extern unsigned char concat_i32_data[5316];
|
extern unsigned char concat_i32_data[5452];
|
||||||
const uint64_t concat_i32_len = 5316;
|
const uint64_t concat_i32_len = 5452;
|
||||||
|
|
||||||
extern unsigned char upscale_f32_data[2856];
|
extern unsigned char upscale_f32_data[2952];
|
||||||
const uint64_t upscale_f32_len = 2856;
|
const uint64_t upscale_f32_len = 2952;
|
||||||
|
|
||||||
extern unsigned char gelu_f32_data[1700];
|
extern unsigned char gelu_f32_data[1700];
|
||||||
const uint64_t gelu_f32_len = 1700;
|
const uint64_t gelu_f32_len = 1700;
|
||||||
|
@ -1896,14 +1896,14 @@ const uint64_t argsort_f32_len = 4200;
|
||||||
extern unsigned char sum_rows_f32_data[2320];
|
extern unsigned char sum_rows_f32_data[2320];
|
||||||
const uint64_t sum_rows_f32_len = 2320;
|
const uint64_t sum_rows_f32_len = 2320;
|
||||||
|
|
||||||
extern unsigned char im2col_f32_data[3672];
|
extern unsigned char im2col_f32_data[4548];
|
||||||
const uint64_t im2col_f32_len = 3672;
|
const uint64_t im2col_f32_len = 4548;
|
||||||
|
|
||||||
extern unsigned char im2col_f32_f16_data[3732];
|
extern unsigned char im2col_f32_f16_data[4600];
|
||||||
const uint64_t im2col_f32_f16_len = 3732;
|
const uint64_t im2col_f32_f16_len = 4600;
|
||||||
|
|
||||||
extern unsigned char im2col_f32_f16_rte_data[3756];
|
extern unsigned char im2col_f32_f16_rte_data[4624];
|
||||||
const uint64_t im2col_f32_f16_rte_len = 3756;
|
const uint64_t im2col_f32_f16_rte_len = 4624;
|
||||||
|
|
||||||
extern unsigned char timestep_embedding_f32_data[2000];
|
extern unsigned char timestep_embedding_f32_data[2000];
|
||||||
const uint64_t timestep_embedding_f32_len = 2000;
|
const uint64_t timestep_embedding_f32_len = 2000;
|
||||||
|
|
|
@ -145,6 +145,8 @@ class vk_perf_logger;
|
||||||
#endif
|
#endif
|
||||||
static void ggml_vk_destroy_buffer(vk_buffer& buf);
|
static void ggml_vk_destroy_buffer(vk_buffer& buf);
|
||||||
|
|
||||||
|
static constexpr uint32_t mul_mat_vec_max_cols = 8;
|
||||||
|
|
||||||
struct vk_device_struct {
|
struct vk_device_struct {
|
||||||
std::mutex mutex;
|
std::mutex mutex;
|
||||||
|
|
||||||
|
@ -202,8 +204,8 @@ struct vk_device_struct {
|
||||||
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
|
vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT];
|
||||||
|
|
||||||
vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
|
vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
|
||||||
vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT];
|
vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
|
||||||
vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT];
|
vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_COUNT][mul_mat_vec_max_cols];
|
||||||
vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
|
vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
|
||||||
|
|
||||||
vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
|
vk_pipeline pipeline_mul_mat_vec_p021_f16_f32;
|
||||||
|
@ -411,7 +413,7 @@ struct vk_op_unary_push_constants {
|
||||||
uint32_t ne;
|
uint32_t ne;
|
||||||
uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
|
uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
|
||||||
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
|
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
|
||||||
uint32_t d_offset;
|
uint32_t misalign_offsets;
|
||||||
float param1; float param2;
|
float param1; float param2;
|
||||||
uint32_t ne0_012mp; uint32_t ne0_012L;
|
uint32_t ne0_012mp; uint32_t ne0_012L;
|
||||||
uint32_t ne0_01mp; uint32_t ne0_01L;
|
uint32_t ne0_01mp; uint32_t ne0_01L;
|
||||||
|
@ -459,7 +461,7 @@ struct vk_op_binary_push_constants {
|
||||||
uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
|
uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
|
||||||
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
|
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
|
||||||
uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23;
|
uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23;
|
||||||
uint32_t d_offset;
|
uint32_t misalign_offsets;
|
||||||
float param1; float param2; int32_t param3;
|
float param1; float param2; int32_t param3;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -546,7 +548,7 @@ struct vk_staging_memcpy {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct vk_op_upscale_push_constants {
|
struct vk_op_upscale_push_constants {
|
||||||
uint32_t ne; uint32_t d_offset;
|
uint32_t ne; uint32_t a_offset; uint32_t d_offset;
|
||||||
uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
|
uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
|
||||||
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
|
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13;
|
||||||
float sf0; float sf1; float sf2; float sf3;
|
float sf0; float sf1; float sf2; float sf3;
|
||||||
|
@ -1404,10 +1406,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
// spec constants and tile sizes for non-quant matmul/matmul_id
|
// spec constants and tile sizes for non-quant matmul/matmul_id
|
||||||
l_warptile = { 256, 128, 256, 64 };
|
l_warptile = { 256, 128, 256, 64 };
|
||||||
m_warptile = { 256, 128, 128, 64 };
|
m_warptile = { 256, 128, 128, 64 };
|
||||||
s_warptile = { 128, 32, 16, 64 };
|
s_warptile = { 128, 64, 64, 64 };
|
||||||
l_wg_denoms = {128, 256, 1 };
|
l_wg_denoms = {128, 256, 1 };
|
||||||
m_wg_denoms = {128, 128, 1 };
|
m_wg_denoms = {128, 128, 1 };
|
||||||
s_wg_denoms = { 32, 16, 1 };
|
s_wg_denoms = { 64, 64, 1 };
|
||||||
|
|
||||||
// spec constants and tile sizes for quant matmul (non-Qi_K)
|
// spec constants and tile sizes for quant matmul (non-Qi_K)
|
||||||
l_warptile_mmq = { 256, 128, 256, 64 };
|
l_warptile_mmq = { 256, 128, 256, 64 };
|
||||||
|
@ -1866,33 +1868,35 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
} else if (device->vendor_id == VK_VENDOR_ID_INTEL)
|
} else if (device->vendor_id == VK_VENDOR_ID_INTEL)
|
||||||
rm_stdq = 2;
|
rm_stdq = 2;
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f32_f32", mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f32_f32", mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32_"+std::to_string(i+1), mul_mat_vec_f32_f32_f32_len, mul_mat_vec_f32_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f32_f32", mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32_"+std::to_string(i+1), mul_mat_vec_f16_f32_f32_len, mul_mat_vec_f16_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f32_f32", mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f32_f32_len, mul_mat_vec_q4_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f32_f32", mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f32_f32_len, mul_mat_vec_q4_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f32_f32", mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f32_f32_len, mul_mat_vec_q5_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f32_f32", mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f32_f32_len, mul_mat_vec_q5_1_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f32_f32", mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f32_f32_len, mul_mat_vec_q8_0_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f32_f32", mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f32_f32_len, mul_mat_vec_q2_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f32_f32", mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f32_f32_len, mul_mat_vec_q3_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f32_f32", mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f32_f32_len, mul_mat_vec_q4_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f32_f32", mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f32_f32_len, mul_mat_vec_q5_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f32_f32", mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f32_f32_len, mul_mat_vec_q6_k_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ], "mul_mat_vec_f32_f16_f32", mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ], "mul_mat_vec_f16_f16_f32", mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32_"+std::to_string(i+1), mul_mat_vec_f16_f16_f32_len, mul_mat_vec_f16_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0], "mul_mat_vec_q4_0_f16_f32", mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_0_f16_f32_len, mul_mat_vec_q4_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1], "mul_mat_vec_q4_1_f16_f32", mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_1_f16_f32_len, mul_mat_vec_q4_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0], "mul_mat_vec_q5_0_f16_f32", mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_0_f16_f32_len, mul_mat_vec_q5_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1], "mul_mat_vec_q5_1_f16_f32", mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_1_f16_f32_len, mul_mat_vec_q5_1_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq, i+1}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0], "mul_mat_vec_q8_0_f16_f32", mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f16_f32_"+std::to_string(i+1), mul_mat_vec_q8_0_f16_f32_len, mul_mat_vec_q8_0_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq, i+1}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K], "mul_mat_vec_q2_k_f16_f32", mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q2_k_f16_f32_len, mul_mat_vec_q2_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K], "mul_mat_vec_q3_k_f16_f32", mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q3_k_f16_f32_len, mul_mat_vec_q3_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K], "mul_mat_vec_q4_k_f16_f32", mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q4_k_f16_f32_len, mul_mat_vec_q4_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K], "mul_mat_vec_q5_k_f16_f32", mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q5_k_f16_f32_len, mul_mat_vec_q5_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K], "mul_mat_vec_q6_k_f16_f32", mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32_"+std::to_string(i+1), mul_mat_vec_q6_k_f16_f32_len, mul_mat_vec_q6_k_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_iq4_nl_f16_f32", mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
|
||||||
|
}
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
|
||||||
|
@ -2017,11 +2021,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
|
||||||
if (device->float_controls_rte_fp16) {
|
if (device->float_controls_rte_fp16) {
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
|
||||||
} else {
|
} else {
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
|
||||||
|
@ -2892,9 +2896,10 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
|
||||||
return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
|
return ctx->device->fp16 ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc;
|
||||||
}
|
}
|
||||||
|
|
||||||
static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
|
static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols) {
|
||||||
VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()");
|
VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()");
|
||||||
GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16);
|
GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16);
|
||||||
|
GGML_ASSERT(num_cols >= 1 && num_cols <= mul_mat_vec_max_cols);
|
||||||
|
|
||||||
switch (a_type) {
|
switch (a_type) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
|
@ -2915,7 +2920,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type];
|
return b_type == GGML_TYPE_F32 ? ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[a_type][num_cols-1];
|
||||||
}
|
}
|
||||||
|
|
||||||
static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
|
static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) {
|
||||||
|
@ -3925,8 +3930,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||||
const uint64_t ne12 = src1->ne[2];
|
const uint64_t ne12 = src1->ne[2];
|
||||||
const uint64_t ne13 = src1->ne[3];
|
const uint64_t ne13 = src1->ne[3];
|
||||||
|
|
||||||
GGML_ASSERT(ne11 == 1);
|
|
||||||
|
|
||||||
const uint64_t ne20 = dst->ne[0];
|
const uint64_t ne20 = dst->ne[0];
|
||||||
const uint64_t ne21 = dst->ne[1];
|
const uint64_t ne21 = dst->ne[1];
|
||||||
const uint64_t ne22 = dst->ne[2];
|
const uint64_t ne22 = dst->ne[2];
|
||||||
|
@ -3935,6 +3938,11 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||||
const uint64_t r2 = ne12 / ne02;
|
const uint64_t r2 = ne12 / ne02;
|
||||||
const uint64_t r3 = ne13 / ne03;
|
const uint64_t r3 = ne13 / ne03;
|
||||||
|
|
||||||
|
// batch_n indicates that we need to compute a few vector results, and this assumes
|
||||||
|
// ne12 and ne13 are 1. It overloads the batch_strides to hold the row strides.
|
||||||
|
GGML_ASSERT(ne11 == 1 || ne12 * ne13 == 1);
|
||||||
|
bool batch_n = ne11 > 1;
|
||||||
|
|
||||||
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
|
||||||
ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
|
ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
|
||||||
ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
|
ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
|
||||||
|
@ -3985,7 +3993,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||||
} else {
|
} else {
|
||||||
to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
|
to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
|
||||||
}
|
}
|
||||||
vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type);
|
vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11);
|
||||||
GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
|
GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
|
||||||
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
|
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
|
||||||
GGML_ASSERT(dmmv != nullptr);
|
GGML_ASSERT(dmmv != nullptr);
|
||||||
|
@ -4057,8 +4065,10 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||||
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
|
ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, { d_Qy, qy_buf_offset, VK_WHOLE_SIZE }, { d_Y, 0, VK_WHOLE_SIZE });
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t stride_batch_x = ne00*ne01;
|
// For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride
|
||||||
uint32_t stride_batch_y = ne10*ne11;
|
uint32_t stride_batch_x = batch_n ? 0 : ne00*ne01;
|
||||||
|
uint32_t stride_batch_y = batch_n ? ne10 : (ne10*ne11);
|
||||||
|
uint32_t stride_batch_d = batch_n ? ne20 : (ne20*ne21);
|
||||||
|
|
||||||
if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
|
if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) {
|
||||||
stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
|
stride_batch_x = src0->nb[0] / ggml_type_size(src0->type);
|
||||||
|
@ -4081,7 +4091,7 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context&
|
||||||
// compute
|
// compute
|
||||||
const vk_mat_vec_push_constants pc = {
|
const vk_mat_vec_push_constants pc = {
|
||||||
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
|
(uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01,
|
||||||
stride_batch_x, stride_batch_y, (uint32_t)(ne20*ne21),
|
stride_batch_x, stride_batch_y, stride_batch_d,
|
||||||
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
|
(uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3,
|
||||||
};
|
};
|
||||||
ggml_vk_sync_buffers(subctx);
|
ggml_vk_sync_buffers(subctx);
|
||||||
|
@ -4261,7 +4271,10 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
||||||
} else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
|
} else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 &&
|
||||||
!ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
|
!ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
|
||||||
ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
|
ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun);
|
||||||
} else if (dst->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
|
// mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four)
|
||||||
|
// when ne12 and ne13 are one.
|
||||||
|
} else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) &&
|
||||||
|
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
|
||||||
ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
|
ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun);
|
||||||
} else {
|
} else {
|
||||||
ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
|
ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, dryrun);
|
||||||
|
@ -5076,6 +5089,57 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static uint32_t get_misalign_bytes(ggml_backend_vk_context * ctx, const ggml_tensor * t)
|
||||||
|
{
|
||||||
|
return ((vk_tensor_offset(t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1));;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
|
||||||
|
GGML_UNUSED(p);
|
||||||
|
GGML_UNUSED(src0);
|
||||||
|
GGML_UNUSED(src1);
|
||||||
|
GGML_UNUSED(src2);
|
||||||
|
GGML_UNUSED(dst);
|
||||||
|
static_assert(!std::is_const<T>::value, "unexpected type");
|
||||||
|
GGML_ASSERT(!src0 || get_misalign_bytes(ctx, src0) == 0);
|
||||||
|
GGML_ASSERT(!src1 || get_misalign_bytes(ctx, src1) == 0);
|
||||||
|
GGML_ASSERT(!src2 || get_misalign_bytes(ctx, src2) == 0);
|
||||||
|
GGML_ASSERT(!dst || get_misalign_bytes(ctx, dst) == 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
|
||||||
|
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
|
||||||
|
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
|
||||||
|
|
||||||
|
p.misalign_offsets = (a_offset << 16) | d_offset;
|
||||||
|
|
||||||
|
GGML_UNUSED(src1);
|
||||||
|
GGML_UNUSED(src2);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
|
||||||
|
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
|
||||||
|
const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
|
||||||
|
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
|
||||||
|
|
||||||
|
GGML_ASSERT(dst->op != GGML_OP_GET_ROWS || (a_offset == 0 && b_offset == 0 && d_offset == 0));
|
||||||
|
|
||||||
|
p.misalign_offsets = (a_offset << 16) | (b_offset << 8) | d_offset;
|
||||||
|
|
||||||
|
GGML_UNUSED(src2);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
|
||||||
|
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
|
||||||
|
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
|
||||||
|
|
||||||
|
p.a_offset = a_offset;
|
||||||
|
p.d_offset = d_offset;
|
||||||
|
|
||||||
|
GGML_UNUSED(src1);
|
||||||
|
GGML_UNUSED(src2);
|
||||||
|
}
|
||||||
|
|
||||||
template<typename PC>
|
template<typename PC>
|
||||||
static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) {
|
static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) {
|
||||||
VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
|
VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
|
||||||
|
@ -5179,8 +5243,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||||
}
|
}
|
||||||
|
|
||||||
GGML_ASSERT(d_D != nullptr);
|
GGML_ASSERT(d_D != nullptr);
|
||||||
uint64_t d_buf_offset = ((vk_tensor_offset(dst) + dst->view_offs) / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
|
uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
|
||||||
GGML_ASSERT(d_buf_offset == vk_tensor_offset(dst) || op == GGML_OP_CPY); // NOLINT
|
|
||||||
if(!src0_uma) {
|
if(!src0_uma) {
|
||||||
d_X = src0_buf_ctx->dev_buffer;
|
d_X = src0_buf_ctx->dev_buffer;
|
||||||
x_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
|
x_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
|
||||||
|
@ -5196,6 +5259,12 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||||
z_buf_offset = vk_tensor_offset(src2) + src2->view_offs;
|
z_buf_offset = vk_tensor_offset(src2) + src2->view_offs;
|
||||||
GGML_ASSERT(d_Z != nullptr);
|
GGML_ASSERT(d_Z != nullptr);
|
||||||
}
|
}
|
||||||
|
// Compute misalignment offset for descriptors and store it in in push constants, then align the descriptor offsets.
|
||||||
|
init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, dst);
|
||||||
|
x_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
|
||||||
|
y_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
|
||||||
|
z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
|
||||||
|
d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1);
|
||||||
|
|
||||||
if (op_supports_incontiguous) {
|
if (op_supports_incontiguous) {
|
||||||
x_sz = ggml_nbytes(src0);
|
x_sz = ggml_nbytes(src0);
|
||||||
|
@ -5383,7 +5452,6 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
||||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||||
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
const uint32_t src1_type_size = ggml_type_size(src1->type);
|
||||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||||
const uint32_t d_offset = ((vk_tensor_offset(dst) + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
|
|
||||||
|
|
||||||
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
|
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
|
||||||
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
|
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
|
||||||
|
@ -5395,7 +5463,7 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
|
||||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size,
|
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size,
|
||||||
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
|
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
|
||||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size,
|
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size,
|
||||||
d_offset,
|
0,
|
||||||
0.0f, 0.0f, offset,
|
0.0f, 0.0f, offset,
|
||||||
}, dryrun);
|
}, dryrun);
|
||||||
}
|
}
|
||||||
|
@ -5599,7 +5667,7 @@ static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, c
|
||||||
const float sf3 = (float)dst->ne[3] / src0->ne[3];
|
const float sf3 = (float)dst->ne[3] / src0->ne[3];
|
||||||
|
|
||||||
ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
|
ggml_vk_op_f32<vk_op_upscale_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, {
|
||||||
(uint32_t)ggml_nelements(dst), 0,
|
(uint32_t)ggml_nelements(dst), 0, 0,
|
||||||
(uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
(uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||||
(uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
|
(uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3],
|
||||||
sf0, sf1, sf2, sf3,
|
sf0, sf1, sf2, sf3,
|
||||||
|
@ -5709,13 +5777,12 @@ static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||||
static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||||
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
const uint32_t src0_type_size = ggml_type_size(src0->type);
|
||||||
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
const uint32_t dst_type_size = ggml_type_size(dst->type);
|
||||||
const uint32_t d_offset = ((vk_tensor_offset(dst) + dst->view_offs) % ctx->device->properties.limits.minStorageBufferOffsetAlignment) / dst_type_size;
|
|
||||||
|
|
||||||
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
|
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
|
||||||
(uint32_t)ggml_nelements(src0),
|
(uint32_t)ggml_nelements(src0),
|
||||||
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
|
||||||
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
|
||||||
d_offset,
|
0,
|
||||||
0.0f, 0.0f,
|
0.0f, 0.0f,
|
||||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
|
||||||
}, dryrun);
|
}, dryrun);
|
||||||
|
|
|
@ -21,9 +21,9 @@ void main() {
|
||||||
get_indices(idx, i00, i01, i02, i03);
|
get_indices(idx, i00, i01, i02, i03);
|
||||||
|
|
||||||
if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
|
if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) {
|
||||||
data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
|
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11]));
|
||||||
} else {
|
} else {
|
||||||
data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]));
|
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,7 @@ void main() {
|
||||||
uint i00, i01, i02, i03;
|
uint i00, i01, i02, i03;
|
||||||
get_indices(idx, i00, i01, i02, i03);
|
get_indices(idx, i00, i01, i02, i03);
|
||||||
|
|
||||||
data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[src1_idx(i00, i01, i02, i03)]));
|
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
|
||||||
|
|
||||||
idx += num_threads;
|
idx += num_threads;
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,6 @@ void main() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
|
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||||
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
|
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val));
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,12 +30,12 @@ void main() {
|
||||||
const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;
|
const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;
|
||||||
|
|
||||||
#ifndef OPTIMIZATION_ERROR_WORKAROUND
|
#ifndef OPTIMIZATION_ERROR_WORKAROUND
|
||||||
data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : data_b[src1_idx]);
|
data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : data_b[get_boffset() + src1_idx]);
|
||||||
#else
|
#else
|
||||||
if (is_src0) {
|
if (is_src0) {
|
||||||
data_d[p.d_offset + dst_idx] = data_a[src0_idx];
|
data_d[get_doffset() + dst_idx] = data_a[get_aoffset() + src0_idx];
|
||||||
} else {
|
} else {
|
||||||
data_d[p.d_offset + dst_idx] = data_b[src1_idx];
|
data_d[get_doffset() + dst_idx] = data_b[get_boffset() + src1_idx];
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,9 +19,9 @@ void main() {
|
||||||
if (idx + (num_iter-1)*num_threads < p.ne) {
|
if (idx + (num_iter-1)*num_threads < p.ne) {
|
||||||
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
|
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
|
||||||
#ifndef OPTIMIZATION_ERROR_WORKAROUND
|
#ifndef OPTIMIZATION_ERROR_WORKAROUND
|
||||||
data_d[p.d_offset + idx] = D_TYPE(data_a[idx]);
|
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
|
||||||
#else
|
#else
|
||||||
data_d[p.d_offset + idx] = data_a[idx];
|
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
|
||||||
#endif
|
#endif
|
||||||
idx += num_threads;
|
idx += num_threads;
|
||||||
}
|
}
|
||||||
|
@ -32,9 +32,9 @@ void main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifndef OPTIMIZATION_ERROR_WORKAROUND
|
#ifndef OPTIMIZATION_ERROR_WORKAROUND
|
||||||
data_d[p.d_offset + idx] = D_TYPE(data_a[idx]);
|
data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]);
|
||||||
#else
|
#else
|
||||||
data_d[p.d_offset + idx] = data_a[idx];
|
data_d[get_doffset() + idx] = data_a[get_aoffset() + idx];
|
||||||
#endif
|
#endif
|
||||||
idx += num_threads;
|
idx += num_threads;
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,8 +13,8 @@ void main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
#ifndef OPTIMIZATION_ERROR_WORKAROUND
|
#ifndef OPTIMIZATION_ERROR_WORKAROUND
|
||||||
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(data_a[src0_idx(idx)]);
|
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||||
#else
|
#else
|
||||||
data_d[p.d_offset + dst_idx(idx)] = data_a[src0_idx(idx)];
|
data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)];
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,6 @@ void main() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
|
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||||
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(cos(val));
|
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(cos(val));
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,7 +20,7 @@ void main() {
|
||||||
uint i00, i01, i02, i03;
|
uint i00, i01, i02, i03;
|
||||||
get_indices(idx, i00, i01, i02, i03);
|
get_indices(idx, i00, i01, i02, i03);
|
||||||
|
|
||||||
data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) / FLOAT_TYPE(data_b[src1_idx(i00, i01, i02, i03)]));
|
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) / FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
|
||||||
|
|
||||||
idx += num_threads;
|
idx += num_threads;
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,7 @@ layout (push_constant) uniform parameter
|
||||||
uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
|
uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
|
||||||
uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
|
uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
|
||||||
uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23;
|
uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23;
|
||||||
uint d_offset;
|
uint misalign_offsets;
|
||||||
float param1; float param2; int param3;
|
float param1; float param2; int param3;
|
||||||
} p;
|
} p;
|
||||||
|
|
||||||
|
@ -22,6 +22,10 @@ uint get_idx() {
|
||||||
return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
uint get_aoffset() { return p.misalign_offsets >> 16; }
|
||||||
|
uint get_boffset() { return (p.misalign_offsets >> 8) & 0xFF; }
|
||||||
|
uint get_doffset() { return p.misalign_offsets & 0xFF; }
|
||||||
|
|
||||||
// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1
|
// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1
|
||||||
uint fastmod(uint a, uint b) {
|
uint fastmod(uint a, uint b) {
|
||||||
if ((b & (b-1)) == 0) {
|
if ((b & (b-1)) == 0) {
|
||||||
|
|
|
@ -6,7 +6,7 @@ layout (push_constant) uniform parameter
|
||||||
uint ne;
|
uint ne;
|
||||||
uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
|
uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
|
||||||
uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
|
uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
|
||||||
uint d_offset;
|
uint misalign_offsets;
|
||||||
float param1; float param2;
|
float param1; float param2;
|
||||||
|
|
||||||
uint ne0_012mp; uint ne0_012L;
|
uint ne0_012mp; uint ne0_012L;
|
||||||
|
@ -24,6 +24,9 @@ uint get_idx() {
|
||||||
return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
uint get_aoffset() { return p.misalign_offsets >> 16; }
|
||||||
|
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
|
||||||
|
|
||||||
// see init_fastdiv_values in ggml-vulkan.cpp
|
// see init_fastdiv_values in ggml-vulkan.cpp
|
||||||
uint fastdiv(uint n, uint mp, uint L) {
|
uint fastdiv(uint n, uint mp, uint L) {
|
||||||
uint msbs, lsbs;
|
uint msbs, lsbs;
|
||||||
|
|
|
@ -15,10 +15,10 @@ void main() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
|
const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
|
||||||
|
|
||||||
const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
|
const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
|
||||||
const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
|
const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
|
||||||
|
|
||||||
#ifndef OPTIMIZATION_ERROR_WORKAROUND
|
#ifndef OPTIMIZATION_ERROR_WORKAROUND
|
||||||
data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]);
|
data_d[d_offset + i00] = D_TYPE(data_a[a_offset + i00]);
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
#extension GL_EXT_shader_16bit_storage : require
|
#extension GL_EXT_shader_16bit_storage : require
|
||||||
#extension GL_EXT_spirv_intrinsics: enable
|
#extension GL_EXT_spirv_intrinsics: enable
|
||||||
|
#extension GL_EXT_control_flow_attributes : require
|
||||||
|
|
||||||
#if RTE16
|
#if RTE16
|
||||||
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
|
spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
|
||||||
|
@ -23,40 +24,64 @@ layout (push_constant) uniform parameter
|
||||||
|
|
||||||
#include "types.comp"
|
#include "types.comp"
|
||||||
|
|
||||||
#define BLOCK_SIZE 256
|
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||||
|
|
||||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
const uint NUM_ITER = 512 / BLOCK_SIZE;
|
||||||
|
|
||||||
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||||
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
const uint i = gl_GlobalInvocationID.x;
|
const uint gidx = gl_GlobalInvocationID.x;
|
||||||
if (i >= p.pelements) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
|
|
||||||
const uint kx = i / ksize;
|
|
||||||
const uint kd = kx * ksize;
|
|
||||||
const uint ky = (i - kd) / p.OW;
|
|
||||||
const uint ix = i % p.OW;
|
|
||||||
|
|
||||||
const uint oh = gl_GlobalInvocationID.y;
|
const uint oh = gl_GlobalInvocationID.y;
|
||||||
const uint batch = gl_GlobalInvocationID.z / p.IC;
|
const uint batch = gl_GlobalInvocationID.z / p.IC;
|
||||||
const uint ic = gl_GlobalInvocationID.z % p.IC;
|
const uint ic = gl_GlobalInvocationID.z % p.IC;
|
||||||
|
|
||||||
const uint iiw = ix * p.s0 + kx * p.d0 - p.p0;
|
A_TYPE values[NUM_ITER];
|
||||||
const uint iih = oh * p.s1 + ky * p.d1 - p.p1;
|
uint offset_dst[NUM_ITER];
|
||||||
|
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
|
||||||
const uint offset_dst =
|
values[idx] = A_TYPE(0);
|
||||||
((batch * p.OH + oh) * p.OW + ix) * p.CHW +
|
|
||||||
(ic * (p.KW * p.KH) + ky * p.KW + kx);
|
|
||||||
|
|
||||||
if (iih < 0 || iih >= p.IH || iiw < 0 || iiw >= p.IW) {
|
|
||||||
data_d[offset_dst] = D_TYPE(0.0f);
|
|
||||||
} else {
|
|
||||||
const uint offset_src = ic * p.offset_delta + batch * p.batch_offset;
|
|
||||||
data_d[offset_dst] = D_TYPE(data_a[offset_src + iih * p.IW + iiw]);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
|
||||||
|
|
||||||
|
const uint i = gidx * NUM_ITER + idx;
|
||||||
|
|
||||||
|
const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
|
||||||
|
const uint kx = i / ksize;
|
||||||
|
const uint kd = kx * ksize;
|
||||||
|
const uint ky = (i - kd) / p.OW;
|
||||||
|
const uint ix = i % p.OW;
|
||||||
|
|
||||||
|
const uint iiw = ix * p.s0 + kx * p.d0 - p.p0;
|
||||||
|
const uint iih = oh * p.s1 + ky * p.d1 - p.p1;
|
||||||
|
|
||||||
|
offset_dst[idx] =
|
||||||
|
((batch * p.OH + oh) * p.OW + ix) * p.CHW +
|
||||||
|
(ic * (p.KW * p.KH) + ky * p.KW + kx);
|
||||||
|
|
||||||
|
if (i >= p.pelements) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (iih < p.IH && iiw < p.IW) {
|
||||||
|
const uint offset_src = ic * p.offset_delta + batch * p.batch_offset;
|
||||||
|
values[idx] = data_a[offset_src + iih * p.IW + iiw];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
|
||||||
|
|
||||||
|
const uint i = gidx * NUM_ITER + idx;
|
||||||
|
|
||||||
|
if (i >= p.pelements) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
data_d[offset_dst[idx]] = D_TYPE(values[idx]);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -20,7 +20,7 @@ void main() {
|
||||||
uint i00, i01, i02, i03;
|
uint i00, i01, i02, i03;
|
||||||
get_indices(idx, i00, i01, i02, i03);
|
get_indices(idx, i00, i01, i02, i03);
|
||||||
|
|
||||||
data_d[p.d_offset + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[src0_idx(i00, i01, i02, i03)]) * FLOAT_TYPE(data_b[src1_idx(i00, i01, i02, i03)]));
|
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) * FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
|
||||||
|
|
||||||
idx += num_threads;
|
idx += num_threads;
|
||||||
}
|
}
|
||||||
|
|
|
@ -9,9 +9,6 @@
|
||||||
|
|
||||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
|
|
||||||
layout (constant_id = 1) const uint NUM_ROWS = 1;
|
|
||||||
|
|
||||||
#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
|
#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
|
||||||
#define K_PER_ITER 8
|
#define K_PER_ITER 8
|
||||||
#else
|
#else
|
||||||
|
@ -21,70 +18,70 @@ layout (constant_id = 1) const uint NUM_ROWS = 1;
|
||||||
|
|
||||||
uint a_offset, b_offset, d_offset, y_offset;
|
uint a_offset, b_offset, d_offset, y_offset;
|
||||||
|
|
||||||
shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
|
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
|
||||||
|
|
||||||
void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
|
|
||||||
{
|
{
|
||||||
const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;
|
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||||
const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
|
const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;
|
||||||
const uint iybs = col - col%QUANT_K; // y block start index
|
const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
|
||||||
|
const uint iybs = col - col%QUANT_K; // y block start index
|
||||||
|
|
||||||
#if K_PER_ITER == 8
|
#if K_PER_ITER == 8
|
||||||
#if QUANT_R == 2
|
#if QUANT_R == 2
|
||||||
const B_TYPE_VEC4 bv02 = data_b_v4[(b_offset + iybs + iqs) / 4];
|
const B_TYPE_VEC4 bv02 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4];
|
||||||
const B_TYPE_VEC4 bv13 = data_b_v4[(b_offset + iybs + iqs + y_offset) / 4];
|
const B_TYPE_VEC4 bv13 = data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4];
|
||||||
const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y);
|
const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y);
|
||||||
const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w);
|
const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w);
|
||||||
#else
|
#else
|
||||||
const vec4 bv0 = vec4(data_b_v4[(b_offset + iybs + iqs) / 4]);
|
const vec4 bv0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);
|
||||||
const vec4 bv1 = vec4(data_b_v4[(b_offset + iybs + iqs) / 4 + 1]);
|
const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]);
|
||||||
#endif
|
#endif
|
||||||
#else
|
#else
|
||||||
// Check if the second of the pair of elements is OOB, and don't fetch B or
|
// Check if the second of the pair of elements is OOB, and don't fetch B or
|
||||||
// accumulate it. We still fetch a pair of elements for A, which is fine for
|
// accumulate it. We still fetch a pair of elements for A, which is fine for
|
||||||
// quantized formats since they'll be within the same block. We should
|
// quantized formats since they'll be within the same block. We should
|
||||||
// probably skip fetching the second element for F16/F32, but as of now we
|
// probably skip fetching the second element for F16/F32, but as of now we
|
||||||
// still do.
|
// still do.
|
||||||
const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols);
|
const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols);
|
||||||
|
|
||||||
FLOAT_TYPE b0 = 0, b1 = 0;
|
FLOAT_TYPE b0 = 0, b1 = 0;
|
||||||
b0 = FLOAT_TYPE(data_b[b_offset + iybs + iqs]);
|
b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]);
|
||||||
if (!OOB) {
|
if (!OOB) {
|
||||||
b1 = FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]);
|
b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
uint ibi = first_row*p.ncols;
|
uint ibi = first_row*p.ncols;
|
||||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||||
const uint ib = (ibi + col)/QUANT_K; // block index
|
const uint ib = (ibi + col)/QUANT_K; // block index
|
||||||
ibi += p.ncols;
|
ibi += p.ncols;
|
||||||
|
|
||||||
#if K_PER_ITER == 8
|
#if K_PER_ITER == 8
|
||||||
vec4 v = dequantize4(ib, iqs, a_offset);
|
vec4 v = dequantize4(ib, iqs, a_offset);
|
||||||
vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset);
|
vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset);
|
||||||
|
|
||||||
const vec2 dm = get_dm(ib, a_offset);
|
const vec2 dm = get_dm(ib, a_offset);
|
||||||
if (dm.y != 0) { // quant has min component
|
if (dm.y != 0) { // quant has min component
|
||||||
v = v * dm.x + dm.y;
|
v = v * dm.x + dm.y;
|
||||||
v2 = v2 * dm.x + dm.y;
|
v2 = v2 * dm.x + dm.y;
|
||||||
}
|
}
|
||||||
|
|
||||||
// matrix multiplication
|
// matrix multiplication
|
||||||
FLOAT_TYPE rowtmp = dot(bv0, v);
|
FLOAT_TYPE rowtmp = dot(bv0, v);
|
||||||
rowtmp += dot(bv1, v2);
|
rowtmp += dot(bv1, v2);
|
||||||
|
|
||||||
if (dm.y == 0)
|
if (dm.y == 0)
|
||||||
rowtmp *= dm.x;
|
rowtmp *= dm.x;
|
||||||
|
|
||||||
temp[n] += rowtmp;
|
temp[j][n] += rowtmp;
|
||||||
#else
|
#else
|
||||||
const vec2 v = dequantize(ib, iqs, a_offset);
|
const vec2 v = dequantize(ib, iqs, a_offset);
|
||||||
|
|
||||||
// matrix multiplication
|
// matrix multiplication
|
||||||
temp[n] = fma(FLOAT_TYPE(v.x), b0, temp[n]);
|
temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]);
|
||||||
if (!OOB) {
|
if (!OOB) {
|
||||||
temp[n] = fma(FLOAT_TYPE(v.y), b1, temp[n]);
|
temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]);
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -96,10 +93,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
|
|
||||||
y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
|
y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
|
||||||
|
|
||||||
FLOAT_TYPE temp[NUM_ROWS];
|
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
|
||||||
|
|
||||||
for (uint i = 0; i < NUM_ROWS; ++i) {
|
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||||
temp[i] = FLOAT_TYPE(0);
|
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
|
||||||
|
temp[j][i] = FLOAT_TYPE(0);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
|
uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
|
||||||
|
@ -131,24 +130,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
i++;
|
i++;
|
||||||
}
|
}
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
reduce_result(temp, d_offset, first_row, num_rows, tid);
|
||||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
|
||||||
tmpsh[n][tid] = temp[n];
|
|
||||||
}
|
|
||||||
barrier();
|
|
||||||
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
|
|
||||||
if (tid < s) {
|
|
||||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
|
||||||
tmpsh[n][tid] += tmpsh[n][tid + s];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
barrier();
|
|
||||||
}
|
|
||||||
if (tid == 0) {
|
|
||||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
|
||||||
data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
|
|
|
@ -83,3 +83,36 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
|
||||||
batch_idx * p.batch_stride_d;
|
batch_idx * p.batch_stride_d;
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
|
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||||
|
layout (constant_id = 1) const uint NUM_ROWS = 1;
|
||||||
|
layout (constant_id = 2) const uint NUM_COLS = 1;
|
||||||
|
|
||||||
|
shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE];
|
||||||
|
|
||||||
|
void reduce_result(const in FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
|
||||||
|
// sum up partial sums and write back result
|
||||||
|
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||||
|
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||||
|
tmpsh[j][n][tid] = temp[j][n];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
|
||||||
|
if (tid < s) {
|
||||||
|
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||||
|
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||||
|
tmpsh[j][n][tid] += tmpsh[j][n][tid + s];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
if (tid == 0) {
|
||||||
|
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||||
|
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||||
|
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -5,11 +5,6 @@
|
||||||
|
|
||||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
|
|
||||||
layout (constant_id = 1) const uint NUM_ROWS = 1;
|
|
||||||
|
|
||||||
shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
|
|
||||||
|
|
||||||
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
uint a_offset, b_offset, d_offset;
|
uint a_offset, b_offset, d_offset;
|
||||||
get_offsets(a_offset, b_offset, d_offset);
|
get_offsets(a_offset, b_offset, d_offset);
|
||||||
|
@ -32,24 +27,17 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
const uint s_offset = 8*v_im;
|
const uint s_offset = 8*v_im;
|
||||||
const uint y_offset = 128*v_im + l0;
|
const uint y_offset = 128*v_im + l0;
|
||||||
|
|
||||||
FLOAT_TYPE temp[NUM_ROWS];
|
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
|
||||||
|
|
||||||
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
|
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||||
temp[i] = FLOAT_TYPE(0);
|
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
|
||||||
|
temp[j][i] = FLOAT_TYPE(0);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
|
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
|
||||||
const uint y_idx = i * QUANT_K + y_offset;
|
const uint y_idx = i * QUANT_K + y_offset;
|
||||||
|
|
||||||
B_TYPE_VEC2 b0 = data_b_v2[(b_offset + y_idx) / 2 + 0];
|
|
||||||
B_TYPE_VEC2 b16 = data_b_v2[(b_offset + y_idx) / 2 + 8];
|
|
||||||
B_TYPE_VEC2 b32 = data_b_v2[(b_offset + y_idx) / 2 + 16];
|
|
||||||
B_TYPE_VEC2 b48 = data_b_v2[(b_offset + y_idx) / 2 + 24];
|
|
||||||
B_TYPE_VEC2 b64 = data_b_v2[(b_offset + y_idx) / 2 + 32];
|
|
||||||
B_TYPE_VEC2 b80 = data_b_v2[(b_offset + y_idx) / 2 + 40];
|
|
||||||
B_TYPE_VEC2 b96 = data_b_v2[(b_offset + y_idx) / 2 + 48];
|
|
||||||
B_TYPE_VEC2 b112 = data_b_v2[(b_offset + y_idx) / 2 + 56];
|
|
||||||
|
|
||||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||||
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
||||||
f16vec2 d = data_a[ib0 + i].d;
|
f16vec2 d = data_a[ib0 + i].d;
|
||||||
|
@ -74,48 +62,42 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
uvec2 qs0 = uvec2(unpack8(qs0_u16));
|
uvec2 qs0 = uvec2(unpack8(qs0_u16));
|
||||||
uvec2 qs16 = uvec2(unpack8(qs16_u16));
|
uvec2 qs16 = uvec2(unpack8(qs16_u16));
|
||||||
|
|
||||||
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
|
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||||
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
|
B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0];
|
||||||
[[unroll]] for (int l = 0; l < 2; ++l) {
|
B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
|
||||||
sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3),
|
B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
|
||||||
fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
|
B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24];
|
||||||
fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3),
|
B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32];
|
||||||
fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3),
|
B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40];
|
||||||
fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3),
|
B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48];
|
||||||
fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3),
|
B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56];
|
||||||
fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3),
|
|
||||||
fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
|
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
|
||||||
sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]),
|
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
|
||||||
fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]),
|
[[unroll]] for (int l = 0; l < 2; ++l) {
|
||||||
fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]),
|
sum1 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_lo4[0]) * FLOAT_TYPE((qs0[l] >> 0) & 3),
|
||||||
fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]),
|
fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_lo4[1]) * FLOAT_TYPE((qs16[l] >> 0) & 3),
|
||||||
fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]),
|
fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_lo4[2]) * FLOAT_TYPE((qs0[l] >> 2) & 3),
|
||||||
fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]),
|
fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_lo4[3]) * FLOAT_TYPE((qs16[l] >> 2) & 3),
|
||||||
fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]),
|
fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_lo4[0]) * FLOAT_TYPE((qs0[l] >> 4) & 3),
|
||||||
fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2))))))));
|
fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_lo4[1]) * FLOAT_TYPE((qs16[l] >> 4) & 3),
|
||||||
|
fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_lo4[2]) * FLOAT_TYPE((qs0[l] >> 6) & 3),
|
||||||
|
fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_lo4[3]) * FLOAT_TYPE((qs16[l] >> 6) & 3), sum1))))))));
|
||||||
|
sum2 = fma(FLOAT_TYPE(b0[l]), FLOAT_TYPE(s0_hi4[0]),
|
||||||
|
fma(FLOAT_TYPE(b16[l]), FLOAT_TYPE(s0_hi4[1]),
|
||||||
|
fma(FLOAT_TYPE(b32[l]), FLOAT_TYPE(s0_hi4[2]),
|
||||||
|
fma(FLOAT_TYPE(b48[l]), FLOAT_TYPE(s0_hi4[3]),
|
||||||
|
fma(FLOAT_TYPE(b64[l]), FLOAT_TYPE(s4_hi4[0]),
|
||||||
|
fma(FLOAT_TYPE(b80[l]), FLOAT_TYPE(s4_hi4[1]),
|
||||||
|
fma(FLOAT_TYPE(b96[l]), FLOAT_TYPE(s4_hi4[2]),
|
||||||
|
fma(FLOAT_TYPE(b112[l]), FLOAT_TYPE(s4_hi4[3]), sum2))))))));
|
||||||
|
}
|
||||||
|
temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
|
||||||
}
|
}
|
||||||
temp[n] = fma(dall, sum1, fma(-dmin, sum2, temp[n]));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
reduce_result(temp, d_offset, first_row, num_rows, tid);
|
||||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
|
||||||
tmpsh[n][tid] = temp[n];
|
|
||||||
}
|
|
||||||
barrier();
|
|
||||||
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
|
|
||||||
if (tid < s) {
|
|
||||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
|
||||||
tmpsh[n][tid] += tmpsh[n][tid + s];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
barrier();
|
|
||||||
}
|
|
||||||
if (tid == 0) {
|
|
||||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
|
||||||
data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
|
|
|
@ -5,11 +5,6 @@
|
||||||
|
|
||||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
|
|
||||||
layout (constant_id = 1) const uint NUM_ROWS = 1;
|
|
||||||
|
|
||||||
shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
|
|
||||||
|
|
||||||
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
uint a_offset, b_offset, d_offset;
|
uint a_offset, b_offset, d_offset;
|
||||||
get_offsets(a_offset, b_offset, d_offset);
|
get_offsets(a_offset, b_offset, d_offset);
|
||||||
|
@ -33,10 +28,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
const uint q_offset = 32*v_im + l0;
|
const uint q_offset = 32*v_im + l0;
|
||||||
const uint y_offset = 128*v_im + l0;
|
const uint y_offset = 128*v_im + l0;
|
||||||
|
|
||||||
FLOAT_TYPE temp[NUM_ROWS];
|
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
|
||||||
|
|
||||||
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
|
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||||
temp[i] = FLOAT_TYPE(0);
|
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
|
||||||
|
temp[j][i] = FLOAT_TYPE(0);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const uint s_shift = 4 * v_im;
|
const uint s_shift = 4 * v_im;
|
||||||
|
@ -44,15 +41,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
|
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
|
||||||
const uint y_idx = i * QUANT_K + y_offset;
|
const uint y_idx = i * QUANT_K + y_offset;
|
||||||
|
|
||||||
B_TYPE_VEC2 b0 = data_b_v2[(b_offset + y_idx) / 2 + 0];
|
|
||||||
B_TYPE_VEC2 b16 = data_b_v2[(b_offset + y_idx) / 2 + 8];
|
|
||||||
B_TYPE_VEC2 b32 = data_b_v2[(b_offset + y_idx) / 2 + 16];
|
|
||||||
B_TYPE_VEC2 b48 = data_b_v2[(b_offset + y_idx) / 2 + 24];
|
|
||||||
B_TYPE_VEC2 b64 = data_b_v2[(b_offset + y_idx) / 2 + 32];
|
|
||||||
B_TYPE_VEC2 b80 = data_b_v2[(b_offset + y_idx) / 2 + 40];
|
|
||||||
B_TYPE_VEC2 b96 = data_b_v2[(b_offset + y_idx) / 2 + 48];
|
|
||||||
B_TYPE_VEC2 b112 = data_b_v2[(b_offset + y_idx) / 2 + 56];
|
|
||||||
|
|
||||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||||
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
||||||
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
|
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
|
||||||
|
@ -70,39 +58,34 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
u8vec2 s8 = unpack8(s8_16);
|
u8vec2 s8 = unpack8(s8_16);
|
||||||
u8vec2 s10 = unpack8(s10_16);
|
u8vec2 s10 = unpack8(s10_16);
|
||||||
|
|
||||||
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
|
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||||
[[unroll]] for (int l = 0; l < 2; ++l) {
|
|
||||||
sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)),
|
B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0];
|
||||||
fma(FLOAT_TYPE(b32[l]) * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)),
|
B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
|
||||||
fma(FLOAT_TYPE(b64[l]) * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)),
|
B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
|
||||||
fma(FLOAT_TYPE(b96[l]) * FLOAT_TYPE(int8_t(((s6[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)),
|
B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24];
|
||||||
fma(FLOAT_TYPE(b16[l]) * FLOAT_TYPE(int8_t(((s0[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)),
|
B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32];
|
||||||
fma(FLOAT_TYPE(b48[l]) * FLOAT_TYPE(int8_t(((s2[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)),
|
B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40];
|
||||||
fma(FLOAT_TYPE(b80[l]) * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)),
|
B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48];
|
||||||
fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum))))))));
|
B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56];
|
||||||
|
|
||||||
|
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
|
||||||
|
[[unroll]] for (int l = 0; l < 2; ++l) {
|
||||||
|
sum = fma(FLOAT_TYPE(b0[l]) * FLOAT_TYPE(int8_t(((s0[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 0)) != 0) ? 0 : 4)),
|
||||||
|
fma(FLOAT_TYPE(b32[l]) * FLOAT_TYPE(int8_t(((s2[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)),
|
||||||
|
fma(FLOAT_TYPE(b64[l]) * FLOAT_TYPE(int8_t(((s4[0] >> s_shift) & 0xF) | ((s8[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)),
|
||||||
|
fma(FLOAT_TYPE(b96[l]) * FLOAT_TYPE(int8_t(((s6[0] >> s_shift) & 0xF) | ((s10[0] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)),
|
||||||
|
fma(FLOAT_TYPE(b16[l]) * FLOAT_TYPE(int8_t(((s0[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 0)) != 0) ? 0 : 4)),
|
||||||
|
fma(FLOAT_TYPE(b48[l]) * FLOAT_TYPE(int8_t(((s2[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 0) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)),
|
||||||
|
fma(FLOAT_TYPE(b80[l]) * FLOAT_TYPE(int8_t(((s4[1] >> s_shift) & 0xF) | ((s8[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)),
|
||||||
|
fma(FLOAT_TYPE(b112[l]) * FLOAT_TYPE(int8_t(((s6[1] >> s_shift) & 0xF) | ((s10[1] >> (s_shift + 2) & 0x3) << 4)) - 32), FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum))))))));
|
||||||
|
}
|
||||||
|
temp[j][n] = fma(d, sum, temp[j][n]);
|
||||||
}
|
}
|
||||||
temp[n] = fma(d, sum, temp[n]);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
reduce_result(temp, d_offset, first_row, num_rows, tid);
|
||||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
|
||||||
tmpsh[n][tid] = temp[n];
|
|
||||||
}
|
|
||||||
barrier();
|
|
||||||
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
|
|
||||||
if (tid < s) {
|
|
||||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
|
||||||
tmpsh[n][tid] += tmpsh[n][tid + s];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
barrier();
|
|
||||||
}
|
|
||||||
if (tid == 0) {
|
|
||||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
|
||||||
data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
|
|
|
@ -6,11 +6,6 @@
|
||||||
|
|
||||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
|
|
||||||
layout (constant_id = 1) const uint NUM_ROWS = 1;
|
|
||||||
|
|
||||||
shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
|
|
||||||
|
|
||||||
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
uint a_offset, b_offset, d_offset;
|
uint a_offset, b_offset, d_offset;
|
||||||
get_offsets(a_offset, b_offset, d_offset);
|
get_offsets(a_offset, b_offset, d_offset);
|
||||||
|
@ -36,21 +31,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
const uint q_offset = 32*v_im + l0;
|
const uint q_offset = 32*v_im + l0;
|
||||||
const uint y_offset = 64*v_im + l0;
|
const uint y_offset = 64*v_im + l0;
|
||||||
|
|
||||||
FLOAT_TYPE temp[NUM_ROWS];
|
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
|
||||||
|
|
||||||
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
|
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||||
temp[i] = FLOAT_TYPE(0);
|
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
|
||||||
|
temp[j][i] = FLOAT_TYPE(0);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
|
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
|
||||||
const uint y1_idx = i * QUANT_K + y_offset;
|
const uint y1_idx = i * QUANT_K + y_offset;
|
||||||
const uint y2_idx = y1_idx + 128;
|
const uint y2_idx = y1_idx + 128;
|
||||||
|
|
||||||
B_TYPE_VEC4 by10 = data_b_v4[(b_offset + y1_idx) / 4];
|
|
||||||
B_TYPE_VEC4 by132 = data_b_v4[(b_offset + y1_idx) / 4 + 8];
|
|
||||||
B_TYPE_VEC4 by20 = data_b_v4[(b_offset + y2_idx) / 4];
|
|
||||||
B_TYPE_VEC4 by232 = data_b_v4[(b_offset + y2_idx) / 4 + 8];
|
|
||||||
|
|
||||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||||
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
||||||
f16vec2 d = data_a[ib0 + i].d;
|
f16vec2 d = data_a[ib0 + i].d;
|
||||||
|
@ -103,37 +95,27 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
const uint32_t q4_14 = qs64_hi4.z;
|
const uint32_t q4_14 = qs64_hi4.z;
|
||||||
const uint32_t q4_15 = qs64_hi4.w;
|
const uint32_t q4_15 = qs64_hi4.w;
|
||||||
|
|
||||||
const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3)));
|
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||||
const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7)));
|
B_TYPE_VEC4 by10 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4];
|
||||||
const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11)));
|
B_TYPE_VEC4 by132 = data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8];
|
||||||
const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15)));
|
B_TYPE_VEC4 by20 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4];
|
||||||
const FLOAT_TYPE smin =
|
B_TYPE_VEC4 by232 = data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8];
|
||||||
fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7,
|
|
||||||
fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
|
const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3)));
|
||||||
fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
|
const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7)));
|
||||||
fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7)))))))))))))));
|
const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11)));
|
||||||
temp[n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[n]));
|
const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15)));
|
||||||
|
const FLOAT_TYPE smin =
|
||||||
|
fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7,
|
||||||
|
fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
|
||||||
|
fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
|
||||||
|
fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7)))))))))))))));
|
||||||
|
temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
reduce_result(temp, d_offset, first_row, num_rows, tid);
|
||||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
|
||||||
tmpsh[n][tid] = temp[n];
|
|
||||||
}
|
|
||||||
barrier();
|
|
||||||
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
|
|
||||||
if (tid < s) {
|
|
||||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
|
||||||
tmpsh[n][tid] += tmpsh[n][tid + s];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
barrier();
|
|
||||||
}
|
|
||||||
if (tid == 0) {
|
|
||||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
|
||||||
data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
|
|
|
@ -6,11 +6,6 @@
|
||||||
|
|
||||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
|
|
||||||
layout (constant_id = 1) const uint NUM_ROWS = 1;
|
|
||||||
|
|
||||||
shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
|
|
||||||
|
|
||||||
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
uint a_offset, b_offset, d_offset;
|
uint a_offset, b_offset, d_offset;
|
||||||
get_offsets(a_offset, b_offset, d_offset);
|
get_offsets(a_offset, b_offset, d_offset);
|
||||||
|
@ -33,25 +28,18 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
const uint q_offset = 32*v_im + l0;
|
const uint q_offset = 32*v_im + l0;
|
||||||
const uint y_offset = 64*v_im + l0;
|
const uint y_offset = 64*v_im + l0;
|
||||||
|
|
||||||
FLOAT_TYPE temp[NUM_ROWS];
|
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
|
||||||
|
|
||||||
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
|
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||||
temp[i] = FLOAT_TYPE(0);
|
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
|
||||||
|
temp[j][i] = FLOAT_TYPE(0);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
|
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
|
||||||
const uint y1_idx = i * QUANT_K + y_offset;
|
const uint y1_idx = i * QUANT_K + y_offset;
|
||||||
const uint y2_idx = y1_idx + 128;
|
const uint y2_idx = y1_idx + 128;
|
||||||
|
|
||||||
B_TYPE_VEC2 by10 = data_b_v2[(b_offset + y1_idx) / 2];
|
|
||||||
B_TYPE_VEC2 by116 = data_b_v2[(b_offset + y1_idx) / 2 + 8];
|
|
||||||
B_TYPE_VEC2 by132 = data_b_v2[(b_offset + y1_idx) / 2 + 16];
|
|
||||||
B_TYPE_VEC2 by148 = data_b_v2[(b_offset + y1_idx) / 2 + 24];
|
|
||||||
B_TYPE_VEC2 by20 = data_b_v2[(b_offset + y2_idx) / 2];
|
|
||||||
B_TYPE_VEC2 by216 = data_b_v2[(b_offset + y2_idx) / 2 + 8];
|
|
||||||
B_TYPE_VEC2 by232 = data_b_v2[(b_offset + y2_idx) / 2 + 16];
|
|
||||||
B_TYPE_VEC2 by248 = data_b_v2[(b_offset + y2_idx) / 2 + 24];
|
|
||||||
|
|
||||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||||
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
||||||
f16vec2 d = data_a[ib0 + i].d;
|
f16vec2 d = data_a[ib0 + i].d;
|
||||||
|
@ -116,53 +104,47 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
const uint32_t q4_14 = qs64_80_hi4.z;
|
const uint32_t q4_14 = qs64_80_hi4.z;
|
||||||
const uint32_t q4_15 = qs64_80_hi4.w;
|
const uint32_t q4_15 = qs64_80_hi4.w;
|
||||||
|
|
||||||
const FLOAT_TYPE sx =
|
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||||
fma(FLOAT_TYPE(by10.x), q4_0,
|
B_TYPE_VEC2 by10 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2];
|
||||||
fma(FLOAT_TYPE(by10.y), q4_1,
|
B_TYPE_VEC2 by116 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8];
|
||||||
fma(FLOAT_TYPE(by116.x), q4_2,
|
B_TYPE_VEC2 by132 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16];
|
||||||
FLOAT_TYPE(by116.y) * q4_3)));
|
B_TYPE_VEC2 by148 = data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24];
|
||||||
const FLOAT_TYPE sy =
|
B_TYPE_VEC2 by20 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2];
|
||||||
fma(FLOAT_TYPE(by132.x), q4_4,
|
B_TYPE_VEC2 by216 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8];
|
||||||
fma(FLOAT_TYPE(by132.y), q4_5,
|
B_TYPE_VEC2 by232 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16];
|
||||||
fma(FLOAT_TYPE(by148.x), q4_6,
|
B_TYPE_VEC2 by248 = data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24];
|
||||||
FLOAT_TYPE(by148.y) * q4_7)));
|
|
||||||
const FLOAT_TYPE sz =
|
const FLOAT_TYPE sx =
|
||||||
fma(FLOAT_TYPE(by20.x), q4_8,
|
fma(FLOAT_TYPE(by10.x), q4_0,
|
||||||
fma(FLOAT_TYPE(by20.y), q4_9,
|
fma(FLOAT_TYPE(by10.y), q4_1,
|
||||||
fma(FLOAT_TYPE(by216.x), q4_10,
|
fma(FLOAT_TYPE(by116.x), q4_2,
|
||||||
FLOAT_TYPE(by216.y) * q4_11)));
|
FLOAT_TYPE(by116.y) * q4_3)));
|
||||||
const FLOAT_TYPE sw =
|
const FLOAT_TYPE sy =
|
||||||
fma(FLOAT_TYPE(by232.x), q4_12,
|
fma(FLOAT_TYPE(by132.x), q4_4,
|
||||||
fma(FLOAT_TYPE(by232.y), q4_13,
|
fma(FLOAT_TYPE(by132.y), q4_5,
|
||||||
fma(FLOAT_TYPE(by248.x), q4_14,
|
fma(FLOAT_TYPE(by148.x), q4_6,
|
||||||
FLOAT_TYPE(by248.y) * q4_15)));
|
FLOAT_TYPE(by148.y) * q4_7)));
|
||||||
const FLOAT_TYPE smin =
|
const FLOAT_TYPE sz =
|
||||||
fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
|
fma(FLOAT_TYPE(by20.x), q4_8,
|
||||||
fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
|
fma(FLOAT_TYPE(by20.y), q4_9,
|
||||||
fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
|
fma(FLOAT_TYPE(by216.x), q4_10,
|
||||||
(FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
|
FLOAT_TYPE(by216.y) * q4_11)));
|
||||||
temp[n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[n]));
|
const FLOAT_TYPE sw =
|
||||||
|
fma(FLOAT_TYPE(by232.x), q4_12,
|
||||||
|
fma(FLOAT_TYPE(by232.y), q4_13,
|
||||||
|
fma(FLOAT_TYPE(by248.x), q4_14,
|
||||||
|
FLOAT_TYPE(by248.y) * q4_15)));
|
||||||
|
const FLOAT_TYPE smin =
|
||||||
|
fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
|
||||||
|
fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
|
||||||
|
fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
|
||||||
|
(FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
|
||||||
|
temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
reduce_result(temp, d_offset, first_row, num_rows, tid);
|
||||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
|
||||||
tmpsh[n][tid] = temp[n];
|
|
||||||
}
|
|
||||||
barrier();
|
|
||||||
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
|
|
||||||
if (tid < s) {
|
|
||||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
|
||||||
tmpsh[n][tid] += tmpsh[n][tid + s];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
barrier();
|
|
||||||
}
|
|
||||||
if (tid == 0) {
|
|
||||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
|
||||||
data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
|
|
|
@ -6,11 +6,6 @@
|
||||||
|
|
||||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
|
|
||||||
layout (constant_id = 1) const uint NUM_ROWS = 1;
|
|
||||||
|
|
||||||
shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
|
|
||||||
|
|
||||||
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
uint a_offset, b_offset, d_offset;
|
uint a_offset, b_offset, d_offset;
|
||||||
get_offsets(a_offset, b_offset, d_offset);
|
get_offsets(a_offset, b_offset, d_offset);
|
||||||
|
@ -36,20 +31,17 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
const uint s_offset = 8*v_im + is;
|
const uint s_offset = 8*v_im + is;
|
||||||
const uint y_offset = 128*v_im + l0;
|
const uint y_offset = 128*v_im + l0;
|
||||||
|
|
||||||
FLOAT_TYPE temp[NUM_ROWS];
|
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
|
||||||
|
|
||||||
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
|
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||||
temp[i] = FLOAT_TYPE(0);
|
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
|
||||||
|
temp[j][i] = FLOAT_TYPE(0);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
|
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) {
|
||||||
const uint y_idx = i * QUANT_K + y_offset;
|
const uint y_idx = i * QUANT_K + y_offset;
|
||||||
|
|
||||||
B_TYPE_VEC4 by0 = data_b_v4[(b_offset + y_idx) / 4];
|
|
||||||
B_TYPE_VEC4 by32 = data_b_v4[(b_offset + y_idx) / 4 + 8];
|
|
||||||
B_TYPE_VEC4 by64 = data_b_v4[(b_offset + y_idx) / 4 + 16];
|
|
||||||
B_TYPE_VEC4 by96 = data_b_v4[(b_offset + y_idx) / 4 + 24];
|
|
||||||
|
|
||||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||||
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
||||||
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
|
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
|
||||||
|
@ -84,35 +76,25 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
uvec4 q2 = uvec4(unpack8(q2_u32));
|
uvec4 q2 = uvec4(unpack8(q2_u32));
|
||||||
uvec4 q3 = uvec4(unpack8(q3_u32));
|
uvec4 q3 = uvec4(unpack8(q3_u32));
|
||||||
|
|
||||||
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
|
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||||
[[unroll]] for (int l = 0; l < 4; ++l) {
|
B_TYPE_VEC4 by0 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4];
|
||||||
sum = fma(FLOAT_TYPE(by0[l]) * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32),
|
B_TYPE_VEC4 by32 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8];
|
||||||
fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32),
|
B_TYPE_VEC4 by64 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16];
|
||||||
fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32),
|
B_TYPE_VEC4 by96 = data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24];
|
||||||
fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum))));
|
|
||||||
|
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
|
||||||
|
[[unroll]] for (int l = 0; l < 4; ++l) {
|
||||||
|
sum = fma(FLOAT_TYPE(by0[l]) * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32),
|
||||||
|
fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32),
|
||||||
|
fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32),
|
||||||
|
fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum))));
|
||||||
|
}
|
||||||
|
temp[j][n] += sum * d;
|
||||||
}
|
}
|
||||||
temp[n] += sum * d;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
reduce_result(temp, d_offset, first_row, num_rows, tid);
|
||||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
|
||||||
tmpsh[n][tid] = temp[n];
|
|
||||||
}
|
|
||||||
barrier();
|
|
||||||
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
|
|
||||||
if (tid < s) {
|
|
||||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
|
||||||
tmpsh[n][tid] += tmpsh[n][tid + s];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
barrier();
|
|
||||||
}
|
|
||||||
if (tid == 0) {
|
|
||||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
|
||||||
data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
|
|
|
@ -24,5 +24,5 @@ void main() {
|
||||||
|
|
||||||
const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;
|
const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03;
|
||||||
|
|
||||||
data_d[p.d_offset + dst_idx] = D_TYPE(is_src0 ? data_a[src0_idx] : 0.0f);
|
data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f);
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,5 +22,5 @@ void main() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(data_a[src0_idx_mod(idx)]);
|
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx_mod(idx)]);
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,7 +18,7 @@ void main() {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
data_d[p.d_offset + idx] = D_TYPE(FLOAT_TYPE(data_a[idx]) * FLOAT_TYPE(p.param1));
|
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1));
|
||||||
idx += num_threads;
|
idx += num_threads;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,6 @@ void main() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
|
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||||
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(sin(val));
|
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sin(val));
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,6 @@ void main() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const FLOAT_TYPE val = FLOAT_TYPE(data_a[src0_idx(idx)]);
|
const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
|
||||||
data_d[p.d_offset + dst_idx(idx)] = D_TYPE(val * val);
|
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val * val);
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
layout (push_constant) uniform parameter
|
layout (push_constant) uniform parameter
|
||||||
{
|
{
|
||||||
uint ne; uint d_offset;
|
uint ne; uint a_offset; uint d_offset;
|
||||||
uint nb00; uint nb01; uint nb02; uint nb03;
|
uint nb00; uint nb01; uint nb02; uint nb03;
|
||||||
uint ne10; uint ne11; uint ne12; uint ne13;
|
uint ne10; uint ne11; uint ne12; uint ne13;
|
||||||
float sf0; float sf1; float sf2; float sf3;
|
float sf0; float sf1; float sf2; float sf3;
|
||||||
|
@ -32,5 +32,5 @@ void main() {
|
||||||
const uint i02 = uint(i12 / p.sf2);
|
const uint i02 = uint(i12 / p.sf2);
|
||||||
const uint i03 = uint(i13 / p.sf3);
|
const uint i03 = uint(i13 / p.sf3);
|
||||||
|
|
||||||
data_d[p.d_offset + idx] = D_TYPE(data_a[i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]);
|
data_d[p.d_offset + idx] = D_TYPE(data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue