Mirror of https://github.com/LostRuins/koboldcpp.git, synced 2025-09-12 18:09:42 +00:00

Merge commit 'c917b67f06' into concedo_experimental

# Conflicts:
#	.devops/tools.sh
#	Makefile
#	ggml/src/ggml-cuda/mmq.cuh
#	tests/test-double-float.cpp
#	tests/test-quantize-fns.cpp
#	tests/test-quantize-perf.cpp

Commit 602661ba49

25 changed files with 1339 additions and 1504 deletions
@@ -1203,11 +1203,10 @@ class RefactModel(Model):
         # TODO: how to determine special FIM tokens automatically?
         special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False,
-                          special_token_types = ['prefix', 'suffix', 'middle', 'fsep', 'eot'])
+                          special_token_types = ['prefix', 'suffix', 'middle', 'eot'])
         special_vocab._set_special_token("prefix", 1)
         special_vocab._set_special_token("suffix", 3)
         special_vocab._set_special_token("middle", 2)
-        special_vocab._set_special_token("fsep", 4) # is this correct?
         special_vocab.add_to_gguf(self.gguf_writer)

     def set_gguf_parameters(self):
@@ -99,7 +99,7 @@ static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {

     char src1_str[128] = {0};
     if (src1) {
-        sprintf(src1_str, "%s{%s}", src1->name, ggml_ne_string(src1).c_str());
+        snprintf(src1_str, sizeof(src1_str), "%s{%s}", src1->name, ggml_ne_string(src1).c_str());
     }

     printf("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__,
@@ -347,7 +347,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) {
             char hex_result[17];
             for (int offset = 0; offset < 8; offset++) {
                 unsigned int shift_bits_by = (8 * (8 - offset - 1));
-                sprintf( ( hex_result + (2*offset)), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff);
+                snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff);
             }

             if (hash_params.manifest_is_usable) {
@@ -384,7 +384,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) {

             char hex_result[41] = {0};
             for (int offset = 0; offset < 20; offset++) {
-                sprintf( ( hex_result + (2*offset)), "%02x", result[offset]&0xff);
+                snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff);
             }

             if (hash_params.manifest_is_usable) {
@@ -421,7 +421,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) {

             char hex_result[SHA256_DIGEST_SIZE * 2 + 1] = {0};
             for (int offset = 0; offset < SHA256_DIGEST_SIZE; offset++) {
-                sprintf( ( hex_result + (2*offset)), "%02x", result[offset]&0xff);
+                snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff);
             }

             if (hash_params.manifest_is_usable) {
@@ -460,7 +460,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) {
         char hex_result[17];
         for (int offset = 0; offset < 8; offset++) {
             unsigned int shift_bits_by = (8 * (8 - offset - 1));
-            sprintf( ( hex_result + (2*offset)), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff);
+            snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", (unsigned char) (hash >> shift_bits_by)&0xff);
         }

         if (hash_params.manifest_is_usable) {
@@ -490,7 +490,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) {

         char hex_result[41];
         for (int offset = 0; offset < 20; offset++) {
-            sprintf( ( hex_result + (2*offset)), "%02x", result[offset]&0xff);
+            snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff);
         }

         if (hash_params.manifest_is_usable) {
@@ -520,7 +520,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) {

         char hex_result[SHA256_DIGEST_SIZE * 2 + 1] = {0};
         for (int offset = 0; offset < SHA256_DIGEST_SIZE; offset++) {
-            sprintf( ( hex_result + (2*offset)), "%02x", result[offset]&0xff);
+            snprintf( ( hex_result + (2*offset)), sizeof(hex_result) - (2*offset), "%02x", result[offset]&0xff);
         }

         if (hash_params.manifest_is_usable) {
@@ -552,7 +552,7 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) {
         generate_uuidv5(result, uuid);

         char string_buffer[37] = {0};
-        sprintf(string_buffer, "%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
+        snprintf(string_buffer, sizeof(string_buffer), "%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x",
             uuid[0], uuid[1], uuid[2],  uuid[3],
             uuid[4], uuid[5], uuid[6],  uuid[7],
             uuid[8], uuid[9], uuid[10], uuid[11],
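All of the hunks above make the same change: every unbounded sprintf into a fixed-size hex buffer becomes a snprintf whose bound is the space left after the current write offset. A standalone sketch of that pattern follows; the helper name and the main() driver are illustrative, not part of the gguf-hash tool itself.

#include <cstdio>
#include <cstdint>
#include <string>

// Format a 64-bit hash as 16 lowercase hex characters, most significant byte first.
static std::string hash64_to_hex(uint64_t hash) {
    char hex_result[17] = {0};
    for (int offset = 0; offset < 8; offset++) {
        const unsigned int shift_bits_by = 8 * (8 - offset - 1);
        // the remaining space shrinks as the write position advances,
        // so the write can never run past the end of the buffer
        snprintf(hex_result + 2*offset, sizeof(hex_result) - 2*offset, "%02x",
                 (unsigned char) ((hash >> shift_bits_by) & 0xff));
    }
    return std::string(hex_result);
}

int main() {
    printf("%s\n", hash64_to_hex(0x0123456789abcdefULL).c_str()); // prints 0123456789abcdef
    return 0;
}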
@@ -290,8 +290,13 @@ int main(int argc, char ** argv) {

     // Should not run without any tokens
     if (embd_inp.empty()) {
-        embd_inp.push_back(llama_token_bos(model));
-        LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
+        if (add_bos) {
+            embd_inp.push_back(llama_token_bos(model));
+            LOG("embd_inp was considered empty and bos was added: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_inp).c_str());
+        } else {
+            LOG_TEE("error: input is empty\n");
+            return -1;
+        }
     }

     // Tokenize negative prompt
@@ -155,7 +155,7 @@ static void test_roundtrip_on_chunk(
     }

     if (use_reference) {
-        qfns.from_float_reference(input_scratch, quantized_scratch, chunk_size);
+        qfns.from_float_ref(input_scratch, quantized_scratch, chunk_size);
     } else {
         qfns.from_float(input_scratch, quantized_scratch, chunk_size);
     }
@@ -2006,6 +2006,11 @@ struct server_context {
         int32_t n_batch  = llama_n_batch(ctx);
         int32_t n_ubatch = llama_n_ubatch(ctx);

+        // track if this is an embedding or non-embedding batch
+        // if we've added sampled tokens above, we are in non-embedding mode
+        // -1: none, 0: non-embedding, 1: embedding
+        int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
+
         // next, batch any pending prompts without exceeding n_batch
         if (params.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
@@ -2176,6 +2181,14 @@ struct server_context {
                     }
                 }

+                // check that we are in the right batch_type, if not defer the slot
+                bool slot_type = slot.embedding ? 1 : 0;
+                if (batch_type == -1) {
+                    batch_type = slot_type;
+                } else if (batch_type != slot_type) {
+                    continue;
+                }
+
                 // keep only the common part
                 int p0 = (int) system_tokens.size() + slot.n_past;
                 if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
@@ -2277,6 +2290,9 @@ struct server_context {
             {"n_tokens", batch.n_tokens},
         });

+        // make sure we're in the right embedding mode
+        llama_set_embeddings(ctx, batch_type == 1);
+
         // process the created batch of tokens
         for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
             const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i);
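The server hunks above introduce a three-state batch type so embedding and non-embedding work is never mixed in one batch: -1 means undecided, 0 non-embedding, 1 embedding; the first slot fixes the type and mismatching slots are deferred to a later batch. A minimal host-side sketch of that gating decision, with an illustrative Slot struct and function name rather than the real server types:

#include <cstdio>
#include <vector>

struct Slot { bool embedding; int id; };

// Decide which slots may join the current batch; mismatching slots are deferred.
static std::vector<int> pick_slots_for_batch(const std::vector<Slot> & slots,
                                             int batch_type /* -1 none, 0 non-embedding, 1 embedding */) {
    std::vector<int> picked;
    for (const Slot & slot : slots) {
        const int slot_type = slot.embedding ? 1 : 0;
        if (batch_type == -1) {
            batch_type = slot_type;   // first slot decides the batch type
        } else if (batch_type != slot_type) {
            continue;                 // defer this slot to a later batch
        }
        picked.push_back(slot.id);
    }
    return picked;
}

int main() {
    std::vector<Slot> slots = {{false, 0}, {true, 1}, {false, 2}};
    for (int id : pick_slots_for_batch(slots, -1)) printf("slot %d joins\n", id); // slots 0 and 2
    return 0;
}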
@@ -2991,6 +3007,11 @@ int main(int argc, char ** argv) {
     };

     const auto handle_completions = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
+        if (ctx_server.params.embedding) {
+            res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
+
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));

         json data = json::parse(req.body);
@@ -3086,6 +3107,11 @@ int main(int argc, char ** argv) {
     };

     const auto handle_chat_completions = [&ctx_server, &params, &res_error](const httplib::Request & req, httplib::Response & res) {
+        if (ctx_server.params.embedding) {
+            res_error(res, format_error_response("This server does not support chat completions. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
+
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);

@@ -3158,6 +3184,11 @@ int main(int argc, char ** argv) {
     };

     const auto handle_infill = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
+        if (ctx_server.params.embedding) {
+            res_error(res, format_error_response("This server does not support infill. Start it without `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+            return;
+        }
+
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));

         json data = json::parse(req.body);
@@ -3244,13 +3275,8 @@ int main(int argc, char ** argv) {
         return res.set_content(data.dump(), "application/json; charset=utf-8");
     };

-    const auto handle_embeddings = [&params, &ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_embeddings = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-        if (!params.embedding) {
-            res.status = 501;
-            res.set_content("This server does not support embeddings. Start it with `--embeddings`", "text/plain; charset=utf-8");
-            return;
-        }

         const json body = json::parse(req.body);
         bool is_openai = false;
@@ -122,8 +122,26 @@ inline std::string format_chat(const struct llama_model * model, const std::stri

     for (size_t i = 0; i < messages.size(); ++i) {
         const auto & curr_msg = messages[i];
-        std::string role    = json_value(curr_msg, "role", std::string(""));
-        std::string content = json_value(curr_msg, "content", std::string(""));
+
+        std::string role = json_value(curr_msg, "role", std::string(""));
+
+        std::string content;
+        if (curr_msg.contains("content")) {
+            if (curr_msg["content"].is_string()) {
+                content = curr_msg["content"].get<std::string>();
+            } else if (curr_msg["content"].is_array()) {
+                for (const auto & part : curr_msg["content"]) {
+                    if (part.contains("text")) {
+                        content += "\n" + part["text"].get<std::string>();
+                    }
+                }
+            } else {
+                throw std::runtime_error("Invalid 'content' type (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
+            }
+        } else {
+            throw std::runtime_error("Missing 'content' (ref: https://github.com/ggerganov/llama.cpp/issues/8367)");
+        }
+
         chat.push_back({role, content});
     }

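The new format_chat logic accepts OpenAI-style messages whose "content" is either a plain string or an array of parts with "text" fields, concatenating the parts and rejecting anything else. A small standalone sketch of the same normalization using nlohmann::json; it assumes the single-header json.hpp is available and the function name is illustrative:

#include <nlohmann/json.hpp>
#include <cstdio>
#include <stdexcept>
#include <string>

using json = nlohmann::json;

// Normalize an OpenAI-style message "content" field into a single string.
static std::string normalize_content(const json & msg) {
    if (!msg.contains("content")) {
        throw std::runtime_error("Missing 'content'");
    }
    const json & c = msg["content"];
    if (c.is_string()) {
        return c.get<std::string>();
    }
    if (c.is_array()) {
        std::string out;
        for (const auto & part : c) {
            if (part.contains("text")) {
                out += "\n" + part["text"].get<std::string>();
            }
        }
        return out;
    }
    throw std::runtime_error("Invalid 'content' type");
}

int main() {
    const json msg = json::parse(
        R"({"role":"user","content":[{"type":"text","text":"Hello"},{"type":"text","text":"world"}]})");
    std::printf("%s\n", normalize_content(msg).c_str());
    return 0;
}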
@@ -721,9 +721,9 @@ extern "C" {
     GGML_API GGML_CALL size_t  ggml_nbytes    (const struct ggml_tensor * tensor);
     GGML_API           size_t  ggml_nbytes_pad(const struct ggml_tensor * tensor); // same as ggml_nbytes() but padded to GGML_MEM_ALIGN

-    GGML_API GGML_CALL int     ggml_blck_size(enum ggml_type type);
+    GGML_API GGML_CALL int64_t ggml_blck_size(enum ggml_type type);
     GGML_API GGML_CALL size_t  ggml_type_size(enum ggml_type type);             // size in bytes for all elements in a block
     GGML_API GGML_CALL size_t  ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row

     GGML_DEPRECATED(
     GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float
@@ -2417,31 +2417,31 @@ extern "C" {
 #endif
     typedef void (*ggml_to_float_t)  (const void  * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
     typedef void (*ggml_from_float_t)(const float * GGML_RESTRICT x, void  * GGML_RESTRICT y, int64_t k);
-    typedef void (*ggml_vec_dot_t)  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx,
-                                     const void * GGML_RESTRICT y, size_t by, int nrc);
-    typedef void (*ggml_from_float_to_mat_t)(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nr,
-                                             int64_t k, int64_t bx);
+    typedef void (*ggml_from_float_to_mat_t)
+                                     (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nr, int64_t k, int64_t bs);
+    typedef void (*ggml_vec_dot_t)  (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x, size_t bx,
+                                     const void * GGML_RESTRICT y, size_t by, int nrc);
     typedef void (*ggml_gemv_t)     (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
                                      const void * GGML_RESTRICT y, int nr, int nc);
     typedef void (*ggml_gemm_t)     (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
                                      const void * GGML_RESTRICT y, int nr, int nc);

     typedef struct {
         const char             * type_name;
-        int                      blck_size;
-        size_t                   type_size;
-        bool                     is_quantized;
-        ggml_to_float_t          to_float;
-        ggml_from_float_t        from_float;
-        ggml_from_float_t        from_float_reference;
-        ggml_vec_dot_t           vec_dot;
-        enum ggml_type           vec_dot_type;
-        int64_t                  nrows; // number of rows to process simultaneously;
-        int64_t                  ncols; // number of columns to process simultaneously;
-        int64_t                  interleave_blcksize; // interleave elements in blocks of interleave_blcksize;
-        ggml_from_float_to_mat_t from_float_to_mat;
-        ggml_gemv_t              gemv;
-        ggml_gemm_t              gemm;
+        int64_t                  blck_size;
+        int64_t                  blck_size_interleave; // interleave elements in blocks
+        size_t                   type_size;
+        bool                     is_quantized;
+        ggml_to_float_t          to_float;
+        ggml_from_float_t        from_float;
+        ggml_from_float_t        from_float_ref;
+        ggml_from_float_to_mat_t from_float_to_mat;
+        ggml_vec_dot_t           vec_dot;
+        enum ggml_type           vec_dot_type;
+        int64_t                  nrows; // number of rows to process simultaneously
+        int64_t                  ncols; // number of columns to process simultaneously
+        ggml_gemv_t              gemv;
+        ggml_gemm_t              gemm;
     } ggml_type_traits_t;

     GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);
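With blck_size widened to int64_t and the members renamed (from_float_ref, blck_size_interleave), code that inspects a type's layout should treat the block size as a 64-bit value. A short sketch of how the traits declared above might be queried; it assumes ggml.h is on the include path and the program is linked against ggml, and the printout is purely illustrative:

#include "ggml.h"
#include <cstdio>

int main() {
    // Query the layout of a quantized type and derive its bytes-per-row.
    const enum ggml_type type = GGML_TYPE_Q4_0;
    const ggml_type_traits_t traits = ggml_internal_get_type_traits(type);

    const int64_t ne  = 4096;                    // elements in one row
    const size_t  row = ggml_row_size(type, ne); // bytes for that row

    printf("%s: blck_size=%lld type_size=%zu row(4096)=%zu bytes\n",
           traits.type_name, (long long) traits.blck_size, traits.type_size, row);
    return 0;
}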
@@ -20,19 +20,19 @@

 // Functions to create the interleaved data layout formats

-// interleave 4 block_q4_0s in blocks of interleave_blcksize
+// interleave 4 block_q4_0s in blocks of blck_size_interleave
 // returns an interleaved block_q4_0x4
 // in the interleaved block_q4_0x4, place deltas for 4 block_q4_0 blocks
-// first, then interleave quants from 4 block_q4_0s in blocks of interleave_blcksize
+// first, then interleave quants from 4 block_q4_0s in blocks of blck_size_interleave
 //
 // - in                  : an array of block_q4_0 pointers
-// - interleave_blcksize : the block_q4_0 quants bytes are interleaved in blocks of
-//                         interleave_blcksize bytes
+// - blck_size_interleave : the block_q4_0 quants bytes are interleaved in blocks of
+//                          blck_size_interleave bytes
 // - xor_mask            : the mask to convert the nibbles in block_q4_0 quants bytes
 //                         from bias offset form to pure sign form (this saves subtract
 //                         operations durin unpacking)
 //
-static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int interleave_blcksize, unsigned int xor_mask) {
+static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) {
     block_q4_0x4 out;

     for (int i = 0; i < 4; i++) {
@@ -40,9 +40,9 @@ static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int interleave_b
     }

     for (int i = 0; i < QK4_0 * 2; i++) {
-        int src_offset = (i / (4 * interleave_blcksize)) * interleave_blcksize;
-        int src_id     = (i % (4 * interleave_blcksize)) / interleave_blcksize;
-        src_offset += (i % interleave_blcksize);
+        int src_offset = (i / (4 * blck_size_interleave)) * blck_size_interleave;
+        int src_id     = (i % (4 * blck_size_interleave)) / blck_size_interleave;
+        src_offset += (i % blck_size_interleave);

         out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask;
     }
@@ -50,11 +50,11 @@ static block_q4_0x4 make_block_q4_0x4(block_q4_0 * in, unsigned int interleave_b
     return out;
 }

-// interleave 8 block_q4_0s in blocks of interleave_blcksize
+// interleave 8 block_q4_0s in blocks of blck_size_interleave
 // returns an interleaved block_q4_0x8
 // in the interleaved block_q4_0x8, place deltas for 8 block_q4_0 blocks
-// first, then interleave quants from 8 block_q4_0s in blocks of interleave_blcksize
-static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int interleave_blcksize, unsigned int xor_mask) {
+// first, then interleave quants from 8 block_q4_0s in blocks of blck_size_interleave
+static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int blck_size_interleave, unsigned int xor_mask) {
     block_q4_0x8 out;

     for (int i = 0; i < 8; i++) {
@@ -62,9 +62,9 @@ static block_q4_0x8 make_block_q4_0x8(block_q4_0 * in, unsigned int interleave_b
     }

     for (int i = 0; i < QK4_0 * 4; i++) {
-        int src_offset = (i / (8 * interleave_blcksize)) * interleave_blcksize;
-        int src_id     = (i % (8 * interleave_blcksize)) / interleave_blcksize;
-        src_offset += (i % interleave_blcksize);
+        int src_offset = (i / (8 * blck_size_interleave)) * blck_size_interleave;
+        int src_id     = (i % (8 * blck_size_interleave)) / blck_size_interleave;
+        src_offset += (i % blck_size_interleave);

         out.qs[i] = in[src_id].qs[src_offset] ^ xor_mask;
     }
@@ -135,7 +135,7 @@ void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k)
     }
 #else
     // scalar
-    const int interleave_blcksize = 4;
+    const int blck_size_interleave = 4;
     float srcv[4][QK8_0];
     float id[4];

@@ -155,12 +155,12 @@ void quantize_q8_0_4x4(const float * restrict x, void * restrict vy, int64_t k)
         }

         for (int j = 0; j < QK8_0 * 4; j++) {
-            int src_offset = (j / (4 * interleave_blcksize)) * interleave_blcksize;
-            int src_id     = (j % (4 * interleave_blcksize)) / interleave_blcksize;
-            src_offset += (j % interleave_blcksize);
+            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
+            int src_id     = (j % (4 * blck_size_interleave)) / blck_size_interleave;
+            src_offset += (j % blck_size_interleave);

             float x0 = srcv[src_id][src_offset] * id[src_id];
-            y[i].qs[j] = roundf(x0);;
+            y[i].qs[j] = roundf(x0);
         }
     }
 #endif
@@ -253,7 +253,7 @@ void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k)
     }
 #else
     // scalar
-    const int interleave_blcksize = 8;
+    const int blck_size_interleave = 8;
     float srcv[4][QK8_0];
     float id[4];

@@ -273,26 +273,30 @@ void quantize_q8_0_4x8(const float * restrict x, void * restrict vy, int64_t k)
         }

         for (int j = 0; j < QK8_0 * 4; j++) {
-            int src_offset = (j / (4 * interleave_blcksize)) * interleave_blcksize;
-            int src_id     = (j % (4 * interleave_blcksize)) / interleave_blcksize;
-            src_offset += (j % interleave_blcksize);
+            int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
+            int src_id     = (j % (4 * blck_size_interleave)) / blck_size_interleave;
+            src_offset += (j % blck_size_interleave);

             float x0 = srcv[src_id][src_offset] * id[src_id];
-            y[i].qs[j] = roundf(x0);;
+            y[i].qs[j] = roundf(x0);
         }
     }
 #endif
 }

-void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nrow, int64_t n_per_row, int64_t interleave_blcksize) {
+void quantize_mat_q8_0(const float * restrict x, void * restrict vy, int64_t nrow, int64_t n_per_row, int64_t blck_size_interleave) {
     assert(nrow == 4);
     UNUSED(nrow);
-    if (interleave_blcksize == 4) quantize_q8_0_4x4(x, vy, n_per_row);
-    else if (interleave_blcksize == 8) quantize_q8_0_4x8(x, vy, n_per_row);
-    else assert(false);
+    if (blck_size_interleave == 4) {
+        quantize_q8_0_4x4(x, vy, n_per_row);
+    } else if (blck_size_interleave == 8) {
+        quantize_q8_0_4x8(x, vy, n_per_row);
+    } else {
+        assert(false);
+    }
 }

-static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, int nrows_interleaved, int interleave_blcksize) {
+static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, int nrows_interleaved, int blck_size_interleave) {
     assert(n_per_row % QK4_0 == 0);
     const int nb = n_per_row / QK4_0;

@@ -311,15 +315,15 @@ static size_t quantize_q4_0_nr_bl(const float * restrict src, void * restrict ds
         for (int64_t x = 0; x < nb; x++) {

             for (int i  = 0; i < nrows_interleaved; i++ ) {
-                quantize_row_q4_0_reference(src + b + i * n_per_row + x * QK4_0, (block_q4_0 *) dst_tmp + i, QK4_0);
+                quantize_row_q4_0_ref(src + b + i * n_per_row + x * QK4_0, (block_q4_0 *) dst_tmp + i, QK4_0);
             }

             if (nrows_interleaved == 8) {
-                *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, interleave_blcksize, 0x88);
+                *(block_q4_0x8 *) out_ptr = make_block_q4_0x8(dst_tmp, blck_size_interleave, 0x88);
                 out_ptr = (block_q4_0x8 *) out_ptr + 1;
             }
             else if (nrows_interleaved == 4) {
-                *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, interleave_blcksize, 0x88);
+                *(block_q4_0x4 *) out_ptr = make_block_q4_0x4(dst_tmp, blck_size_interleave, 0x88);
                 out_ptr = (block_q4_0x4 *) out_ptr + 1;
             }
         }
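The rename from interleave_blcksize to blck_size_interleave does not change the interleaving arithmetic: for 4 interleaved rows with an interleave block of b bytes, destination byte i is taken from row (i % (4*b)) / b at offset (i / (4*b))*b + i % b. A tiny self-contained check of that mapping, using small illustrative sizes instead of QK4_0:

#include <cstdio>

int main() {
    const int nrows = 4;
    const int blck_size_interleave = 4;   // bytes taken from one row at a time
    const int row_bytes = 8;              // per-row bytes in this toy example

    // Walk the interleaved output and print which (row, offset) feeds each byte.
    for (int i = 0; i < nrows * row_bytes; i++) {
        int src_offset = (i / (nrows * blck_size_interleave)) * blck_size_interleave;
        int src_id     = (i % (nrows * blck_size_interleave)) / blck_size_interleave;
        src_offset    += (i % blck_size_interleave);
        printf("out[%2d] <- row %d, byte %d\n", i, src_id, src_offset);
    }
    return 0;
}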
@@ -16,7 +16,7 @@ extern "C" {
 void quantize_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);

-void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t interleave_blcksize);
+void quantize_mat_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t nrows, int64_t n_per_row, int64_t blck_size_interleave);

 // Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
 size_t quantize_q4_0_4x4(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);
@@ -24,14 +24,14 @@ size_t quantize_q4_0_4x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT d
 size_t quantize_q4_0_8x8(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);

 // GEMV
-void ggml_gemv_q4_0_4x4_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemv_q4_0_4x8_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemv_q4_0_8x8_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);

 // GEMM
-void ggml_gemm_q4_0_4x4_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemm_q4_0_4x8_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
-void ggml_gemm_q4_0_8x8_q8_0 (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
+void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);

 #ifdef __cplusplus
 }
@@ -394,7 +394,7 @@ void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event)

 // backend registry

-#define GGML_REG_MAX_BACKENDS 16
+#define GGML_REG_MAX_BACKENDS 64

 struct ggml_backend_reg {
     char name[128];
@@ -8,11 +8,12 @@
 #  include <Accelerate/Accelerate.h>
 #elif defined(GGML_BLAS_USE_MKL)
 #  include <mkl.h>
+#elif defined(GGML_BLAS_USE_BLIS)
+#  include <blis.h>
+#elif defined(GGML_BLAS_USE_NVPL)
+#  include <nvpl_blas.h>
 #else
 #  include <cblas.h>
-#  ifdef BLIS_ENABLE_CBLAS
-#    include <blis.h>
-#  endif
 #endif

 struct ggml_backend_blas_context {
@@ -140,10 +141,14 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
     openblas_set_num_threads(ctx->n_threads);
 #endif

-#if defined(BLIS_ENABLE_CBLAS)
+#if defined(GGML_BLAS_USE_BLIS)
     bli_thread_set_num_threads(ctx->n_threads);
 #endif

+#if defined(GGML_BLAS_USE_NVPL)
+    nvpl_blas_set_num_threads(ctx->n_threads);
+#endif
+
     for (int64_t i13 = 0; i13 < ne13; i13++) {
         for (int64_t i12 = 0; i12 < ne12; i12++) {
             const int64_t i03 = i13/r3;
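The BLAS backend now selects its vendor header and thread-count call with mutually exclusive GGML_BLAS_USE_* macros instead of probing BLIS_ENABLE_CBLAS. A compressed sketch of that compile-time dispatch; the macro names and the two vendor calls are taken from the hunk above, while the wrapper function and main() are illustrative:

// Exactly one GGML_BLAS_USE_* macro is expected to be defined by the build system.
static void set_blas_threads(int n_threads) {
#if defined(GGML_BLAS_USE_BLIS)
    bli_thread_set_num_threads(n_threads);   // BLIS
#elif defined(GGML_BLAS_USE_NVPL)
    nvpl_blas_set_num_threads(n_threads);    // NVPL BLAS
#else
    (void) n_threads;                        // OpenBLAS/MKL are configured elsewhere in the backend
#endif
}

int main() {
    set_blas_threads(8);
    return 0;
}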
@@ -104,7 +104,7 @@
 #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
 #define cudaStream_t hipStream_t
 #define cudaSuccess hipSuccess
-#define __trap abort
+#define __trap() do { abort(); __builtin_unreachable(); } while(0)
 #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
 #define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
 #define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
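Mapping __trap to a bare abort leaves the compiler assuming the macro can return; the do { abort(); __builtin_unreachable(); } while(0) form is callable with parentheses, behaves as a single statement, and tells the compiler control never continues past it. A host-side illustration of the same macro shape (names are illustrative; __builtin_unreachable is a GCC/Clang builtin):

#include <cstdio>
#include <cstdlib>

// Statement-like, callable as my_trap();, and marked unreachable after the call.
#define my_trap() do { abort(); __builtin_unreachable(); } while (0)

static int checked_div(int a, int b) {
    if (b == 0) {
        my_trap(); // no "missing return" warning: the compiler knows we never get past this
    }
    return a / b;
}

int main() {
    printf("%d\n", checked_div(10, 2));
    return 0;
}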
@@ -70,6 +70,10 @@ struct mma_int_A_I16K8 {
         }
 #endif // defined(INT8_MMA_AVAILABLE)
     }
+
+    __device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) {
+        ((mma_int_A_I16K4 *) x)[0].load(xs0, stride);
+    }
 };

 struct mma_int_B_J8K4 {
(One file's diff was suppressed by the viewer because it is too large.)
@@ -37,47 +37,92 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
     reinterpret_cast<half&>(y[ib].ds.y) = sum;
 }

-template <bool need_sum>
+template <mmq_q8_1_ds_layout ds_layout>
 static __global__ void quantize_mmq_q8_1(
     const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {

-    const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
+    constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
+    constexpr int vals_per_sum   = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
+
+    const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;

     if (ix0 >= kx0_padded) {
         return;
     }

+    const float4 * x4 = (const float4 *) x;
+
     const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;

     block_q8_1_mmq * y = (block_q8_1_mmq *) vy;

-    const int64_t ib0 = blockIdx.z*(gridDim.y*gridDim.x*blockDim.x/(4*QK8_1)); // first block of channel
+    const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
     const int64_t ib  = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y;               // block index in channel
     const int64_t iqs = ix0 % (4*QK8_1);                                        // quant index in block

-    const float xi = ix0 < kx0 ? x[ix1*kx0 + ix0] : 0.0f;
-    float amax = fabsf(xi);
-
-    amax = warp_reduce_max(amax);
-
-    float sum;
-    if (need_sum) {
-        sum = warp_reduce_sum(xi);
-    }
-
-    const float d = amax / 127;
-    const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
-
-    y[ib].qs[iqs] = q;
-
-    if (iqs % QK8_1 != 0) {
-        return;
-    }
-
-    if (need_sum) {
-        y[ib].ds[iqs/QK8_1] = make_half2(d, sum);
-    } else {
-        ((float *) y[ib].ds)[iqs/QK8_1] = d;
-    }
+    // Load 4 floats per thread and calculate max. abs. value between them:
+    const float4 xi = ix0 < kx0 ? x4[(ix1*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
+    float amax = fabsf(xi.x);
+    amax = fmaxf(amax, fabsf(xi.y));
+    amax = fmaxf(amax, fabsf(xi.z));
+    amax = fmaxf(amax, fabsf(xi.w));
+
+    // Exchange max. abs. value between vals_per_scale/4 threads.
+#pragma unroll
+    for (int mask = vals_per_scale/8; mask > 0; mask >>= 1) {
+        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
+    }
+
+    float sum;
+    if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
+        sum = xi.x + xi.y + xi.z + xi.w;
+
+        // Exchange calculate sum across vals_per_sum/4 threads.
+#pragma unroll
+        for (int mask = vals_per_sum/8; mask > 0; mask >>= 1) {
+            sum += __shfl_xor_sync(0xFFFFFFFF, sum, mask, WARP_SIZE);
+        }
+    }
+
+    const float d_inv = 127.0f / amax;
+    char4 q;
+    q.x = roundf(xi.x*d_inv);
+    q.y = roundf(xi.y*d_inv);
+    q.z = roundf(xi.z*d_inv);
+    q.w = roundf(xi.w*d_inv);
+
+    // Write back 4 int8 values as a single 32 bit value for better memroy bandwidth:
+    char4 * yqs4 = (char4 *) y[ib].qs;
+    yqs4[iqs/4] = q;
+
+    if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) {
+        if (iqs % 16 != 0 || iqs >= 96) {
+            return;
+        }
+
+        y[ib].d2s6[2 + iqs/16] = sum;
+
+        if (iqs % 64 != 0) {
+            return;
+        }
+
+        const float d = 1.0f / d_inv;
+
+        y[ib].d2s6[iqs/64] = d;
+
+        return;
+    }
+
+    if (iqs % 32 != 0) {
+        return;
+    }
+
+    const float d = 1.0f / d_inv;
+
+    if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) {
+        y[ib].ds4[iqs/32] = make_half2(d, sum);
+    } else {
+        y[ib].d4[iqs/32] = d;
+    }
 }

@@ -101,12 +146,24 @@ void quantize_mmq_q8_1_cuda(

     GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);

-    const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+    const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
     const dim3 num_blocks(block_num_x, kx1, channels);
-    const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
-    if (mmq_need_sum(type_x)) {
-        quantize_mmq_q8_1<true><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
-    } else {
-        quantize_mmq_q8_1<false><<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+    const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
+    switch (mmq_get_q8_1_ds_layout(type_x)) {
+        case MMQ_Q8_1_DS_LAYOUT_D4:
+            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+            break;
+        case MMQ_Q8_1_DS_LAYOUT_DS4:
+            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+            break;
+        case MMQ_Q8_1_DS_LAYOUT_D2S6:
+            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
+                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
     }
 }
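The rewritten kernel has each thread load four consecutive floats as one float4, reduce the absolute maximum across the threads that share a scale, and store four quantized bytes at once. A scalar C++ sketch of the per-4-value part of that math, without the warp shuffles; the names are illustrative and amax is assumed non-zero here:

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
    const float xi[4] = {0.10f, -0.75f, 0.40f, 0.05f};

    // max. abs. value of the 4 loaded floats (the kernel then widens this across the warp)
    float amax = std::fabs(xi[0]);
    for (int k = 1; k < 4; k++) amax = std::fmax(amax, std::fabs(xi[k]));

    const float d_inv = 127.0f / amax;   // one scale per block of values
    int8_t q[4];
    for (int k = 0; k < 4; k++) q[k] = (int8_t) std::round(xi[k] * d_inv);

    const float d = 1.0f / d_inv;        // stored so the consumer can rescale the int8 values
    printf("d=%f q=%d %d %d %d\n", d, q[0], q[1], q[2], q[3]);
    return 0;
}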
@@ -5,7 +5,11 @@

 #include <cstdint>

 #define CUDA_QUANTIZE_BLOCK_SIZE 256
+#define CUDA_QUANTIZE_BLOCK_SIZE_MMQ 128
+
+static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk of out-of-bounds access.");
+static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");

 typedef void (*quantize_cuda_t)(
     const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
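The header now pairs each block-size constant with a static_assert that the row padding divides evenly, so a mismatch is caught at compile time instead of surfacing as an out-of-bounds access. The same guard pattern in isolation, with illustrative constants:

#include <cstdio>

#define MY_ROW_PADDING    512
#define MY_BLOCK_SIZE     256
#define MY_BLOCK_SIZE_MMQ 128

// Fails to compile if the padding stops being a multiple of the launch granularity.
static_assert(MY_ROW_PADDING % MY_BLOCK_SIZE == 0,         "Risk of out-of-bounds access.");
static_assert(MY_ROW_PADDING % (4*MY_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");

int main() {
    printf("padding %d is compatible with block sizes %d and 4*%d\n",
           MY_ROW_PADDING, MY_BLOCK_SIZE, MY_BLOCK_SIZE_MMQ);
    return 0;
}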
@@ -189,7 +189,7 @@ template <int vdr> static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp
 }

 #define VDR_Q2_K_Q8_1_MMVQ 1
-#define VDR_Q2_K_Q8_1_MMQ  2
+#define VDR_Q2_K_Q8_1_MMQ  4

 // contiguous v/x values
 static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
@@ -219,32 +219,56 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq(
     return dm2f.x*sumf_d - dm2f.y*sumf_m;
 }

-// contiguous u/y values
+// contiguous v/x + u/y values
+template <int ns8>
 static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq(
-    const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8) {
-
-    float sumf_d = 0.0f;
-    float sumf_m = 0.0f;
-
-#pragma unroll
-    for (int i0 = 0; i0 < QI8_1; i0 += QI8_1/2) {
-        const float2 dm2f = __half22float2(dm2[i0/(QI8_1/2)]);
-        int sumi_d = 0;
-        int sumi_m = 0;
-
-        const int vi0 = v[i0/(QI8_1/2)];
-#pragma unroll
-        for (int i = i0; i < i0 + QI8_1/2; ++i) {
-            const int vi = (vi0 >> (2*(i % (QI8_1/2)))) & 0x03030303;
-            sumi_d = ggml_cuda_dp4a(vi, u[i], sumi_d); // SIMD dot product
-            sumi_m = ggml_cuda_dp4a(0x01010101, u[i], sumi_m);
-        }
-
-        sumf_d += dm2f.x * sumi_d;
-        sumf_m += dm2f.y * sumi_m;
-    }
-
-    return d8*(sumf_d - sumf_m);
+    const int * __restrict__ v, const int * __restrict__ u, const half2 * dm2, const float & d8, const half2 * s8) {
+
+    float sumf    = 0.0f;
+    float sumf_d8 = 0.0f;
+
+#pragma unroll
+    for (int i0 = 0; i0 < QR2_K*VDR_Q2_K_Q8_1_MMQ; i0 += QI8_1) {
+        const float2 dm2f0 = __half22float2(dm2[i0/(QI8_1/2) + 0]);
+        int sumi_d0 = 0;
+
+        const float2 dm2f1 = __half22float2(dm2[i0/(QI8_1/2) + 1]);
+        int sumi_d1 = 0;
+
+#pragma unroll
+        for (int i = i0; i < i0 + QI8_1/2; ++i) {
+            sumi_d0 = ggml_cuda_dp4a(v[i], u[i], sumi_d0);
+        }
+        sumf_d8 += dm2f0.x * sumi_d0;
+
+#pragma unroll
+        for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
+            sumi_d1 = ggml_cuda_dp4a(v[i], u[i], sumi_d1);
+        }
+        sumf_d8 += dm2f1.x * sumi_d1;
+
+        if (i0/QI8_1 < ns8) {
+            const float2 s8f = __half22float2(s8[i0/QI8_1]);
+            sumf -= dm2f0.y*s8f.x;
+            sumf -= dm2f1.y*s8f.y;
+        } else {
+            int sumi_m0 = 0;
+#pragma unroll
+            for (int i = i0; i < i0 + QI8_1/2; ++i) {
+                sumi_m0 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m0);
+            }
+            sumf_d8 -= dm2f0.y * sumi_m0;
+
+            int sumi_m1 = 0;
+#pragma unroll
+            for (int i = i0 + QI8_1/2; i < i0 + QI8_1; ++i) {
+                sumi_m1 = ggml_cuda_dp4a(0x01010101, u[i], sumi_m1);
+            }
+            sumf_d8 -= dm2f1.y * sumi_m1;
+        }
+    }
+
+    return sumf + d8*sumf_d8;
 }

 #define VDR_Q3_K_Q8_1_MMVQ 1
@@ -283,7 +307,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq(
     return d3 * sumf;
 }

-// contiguous u/y values
+// contiguous v/x + u/y values
 static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(
     const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales,
     const float & d3, const float & d8) {
@@ -296,8 +320,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq(

 #pragma unroll
         for (int i = i0; i < i0 + QI8_1/2; ++i) {
-            const int vi = __vsubss4((v[i/2] >> (4*(i%2))) & 0x0F0F0F0F, 0x04040404);
-            sumi_sc = ggml_cuda_dp4a(vi, u[i], sumi_sc); // SIMD dot product
+            sumi_sc = ggml_cuda_dp4a(v[i], u[i], sumi_sc); // SIMD dot product
         }

         sumi += sumi_sc * scales[i0 / (QI8_1/2)];
@@ -334,7 +357,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq(
     return dm4f.x*sumf_d - dm4f.y*sumf_m;
 }

-// contiguous u/y values
+// contiguous v/x + u/y values
 static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq(
     const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
     const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
@@ -397,7 +420,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq(
     return dm5f.x*sumf_d - dm5f.y*sumf_m;
 }

-// contiguous u/y values
+// contiguous v/x + u/y values
 static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq(
     const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc,
     const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) {
@@ -451,13 +474,16 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
     return d*sumf;
 }

-// contiguous u/y values
+// contiguous v/x + u/y values
 static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
     const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc,
     const float & d6, const float * __restrict__ d8) {

     float sumf_d = 0.0f;

+    const int      sc_packed = get_int_b4(sc, 0);
+    const int8_t * sc_reg    = (const int8_t *) &sc_packed;
+
 #pragma unroll
     for (int i0 = 0; i0 < VDR_Q6_K_Q8_1_MMQ; i0 += 4) {
         int2 sumi_d = {0, 0}; // 2 q6_K scales per q8_1 scale
@@ -471,7 +497,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
             sumi_d.y = ggml_cuda_dp4a(v[2*i+5], u[2*i+5], sumi_d.y); // SIMD dot product
         }

-        sumf_d += d8[i0/4] * (sc[i0/2+0]*sumi_d.x + sc[i0/2+1]*sumi_d.y);
+        sumf_d += d8[i0/4] * (sc_reg[i0/2+0]*sumi_d.x + sc_reg[i0/2+1]*sumi_d.y);
     }

     return d6 * sumf_d;
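These dot-product helpers now receive v/x values that are already unpacked into contiguous ints and feed them straight into dp4a, which accumulates four int8 products per instruction. A scalar C++ emulation of that byte-wise building block; the helper mirrors what ggml_cuda_dp4a computes and the names and sample values are illustrative:

#include <cstdint>
#include <cstdio>

// Emulate a dp4a step: treat two 32-bit words as 4 signed bytes each,
// multiply them pairwise and add the 4 products onto the accumulator.
static int dp4a_emulated(int32_t v, int32_t u, int32_t acc) {
    for (int k = 0; k < 4; k++) {
        const int8_t a = (int8_t) (v >> (8*k));
        const int8_t b = (int8_t) (u >> (8*k));
        acc += (int) a * (int) b;
    }
    return acc;
}

int main() {
    // bytes {1, 2, 3, 4} dot bytes {10, 20, 30, 40} = 300
    const int32_t v = 0x04030201;
    const int32_t u = 0x281E140A;
    printf("%d\n", dp4a_emulated(v, u, 0));
    return 0;
}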
@ -193,16 +193,16 @@ enum ggml_metal_kernel_type {
|
||||||
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
|
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
|
||||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
||||||
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
|
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
||||||
|
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
||||||
|
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
|
||||||
|
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
|
GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
|
GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
|
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
|
|
||||||
GGML_METAL_KERNEL_TYPE_CONCAT,
|
GGML_METAL_KERNEL_TYPE_CONCAT,
|
||||||
GGML_METAL_KERNEL_TYPE_SQR,
|
GGML_METAL_KERNEL_TYPE_SQR,
|
||||||
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
||||||
|
@ -651,14 +651,14 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
|
||||||
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
|
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_IQ4_NL, cpy_f32_iq4_nl, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
||||||
|
@ -810,8 +810,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
||||||
switch (op->src[0]->type) {
|
switch (op->src[0]->type) {
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
switch (op->type) {
|
switch (op->type) {
|
||||||
case GGML_TYPE_F16:
|
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
|
case GGML_TYPE_F16:
|
||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
case GGML_TYPE_Q4_1:
|
case GGML_TYPE_Q4_1:
|
||||||
|
@ -824,8 +824,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
||||||
}
|
}
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
switch (op->type) {
|
switch (op->type) {
|
||||||
case GGML_TYPE_F16:
|
|
||||||
case GGML_TYPE_F32:
|
case GGML_TYPE_F32:
|
||||||
|
case GGML_TYPE_F16:
|
||||||
return true;
|
return true;
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
|
@ -837,7 +837,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
||||||
case GGML_OP_DIAG_MASK_INF:
|
case GGML_OP_DIAG_MASK_INF:
|
||||||
case GGML_OP_GET_ROWS:
|
case GGML_OP_GET_ROWS:
|
||||||
{
|
{
|
||||||
return op->src[0]->type != GGML_TYPE_BF16 && op->ne[3] == 1;
|
return op->ne[3] == 1;
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
|
@@ -1580,8 +1580,8 @@ static enum ggml_status ggml_metal_graph_compute(
                // some Metal matrix data types require aligned pointers
                // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
                switch (src0->type) {
                    case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
                    case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8  == 0); break;
                    default: break;
                }

@@ -2775,8 +2775,8 @@ static enum ggml_status ggml_metal_graph_compute(
                GGML_ASSERT(ne0 % ggml_blck_size(dst->type) == 0);

                switch (dstt) {
-                   case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline;  break;
-                   case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;  break;
+                   case GGML_TYPE_F32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline;  break;
+                   case GGML_TYPE_F16:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline;  break;
                    case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
                    case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0].pipeline; break;
                    case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1].pipeline; break;
@@ -2789,8 +2789,8 @@ static enum ggml_status ggml_metal_graph_compute(
            case GGML_TYPE_F16:
                {
                    switch (dstt) {
-                       case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
-                       case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
+                       case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F32].pipeline; break;
+                       case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline; break;
                        default: GGML_ASSERT(false && "not implemented");
                    };
                } break;
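The aligned-pointer hunk above only asserts what the comment already states: simdgroup loads want float rows on a 16-byte stride and half rows on an 8-byte stride. A small hedged C++ sketch of that invariant, with an invented mini-tensor type that is not ggml API:

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical 2-D tensor view: nb1 mirrors ggml's nb01, the byte stride
// between consecutive rows. Field names are illustrative only.
struct Tensor2D {
    std::size_t nb1;     // bytes between rows
    bool        is_f16;  // element type: half (true) or float (false)
};

// Same rule as the GGML_ASSERTs in the hunk above.
inline void check_simdgroup_row_alignment(const Tensor2D & t) {
    if (t.is_f16) {
        assert(t.nb1 % 8  == 0 && "half rows must be 8-byte aligned");
    } else {
        assert(t.nb1 % 16 == 0 && "float rows must be 16-byte aligned");
    }
}
```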
@@ -1219,9 +1219,10 @@ kernel void kernel_mul_mv_q8_0_f32(
    kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
}

-#define N_F32_F32 4
+#define N_MV_T_T 4

-void kernel_mul_mv_f32_f32_impl(
+template<typename T0, typename T04, typename T1, typename T14>
+void kernel_mul_mv_impl(
        device const char  * src0,
        device const char  * src1,
        device       float * dst,
@@ -1239,13 +1240,12 @@ void kernel_mul_mv_f32_f32_impl(
        uint64_t nb12,
        int64_t  ne0,
        int64_t  ne1,
        uint     r2,
        uint     r3,
        uint3    tgpig,
        uint     tiisg) {

    const int64_t r0 = tgpig.x;
-   const int64_t rb = tgpig.y*N_F32_F32;
+   const int64_t rb = tgpig.y*N_MV_T_T;
    const int64_t im = tgpig.z;

    const uint i12 = im%ne12;
@@ -1253,20 +1253,20 @@ void kernel_mul_mv_f32_f32_impl(

    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;

-   device const float * x = (device const float *) (src0 + offset0);
+   device const T0 * x = (device const T0 *) (src0 + offset0);

    if (ne00 < 128) {
-       for (int row = 0; row < N_F32_F32; ++row) {
+       for (int row = 0; row < N_MV_T_T; ++row) {
            int r1 = rb + row;
            if (r1 >= ne11) {
                break;
            }

-           device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+           device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);

            float sumf = 0;
            for (int i = tiisg; i < ne00; i += 32) {
-               sumf += (float) x[i] * (float) y[i];
+               sumf += (T0) x[i] * (T1) y[i];
            }

            float all_sum = simd_sum(sumf);
@@ -1275,32 +1275,32 @@ void kernel_mul_mv_f32_f32_impl(
            }
        }
    } else {
-       device const float4 * x4 = (device const float4 *) x;
-       for (int row = 0; row < N_F32_F32; ++row) {
+       device const T04 * x4 = (device const T04 *) x;
+       for (int row = 0; row < N_MV_T_T; ++row) {
            int r1 = rb + row;
            if (r1 >= ne11) {
                break;
            }

-           device const float  * y  = (device const float  *) (src1 + r1*nb11 + im*nb12);
-           device const float4 * y4 = (device const float4 *) y;
+           device const T1  * y  = (device const T1  *) (src1 + r1*nb11 + im*nb12);
+           device const T14 * y4 = (device const T14 *) y;

            float sumf = 0;
            for (int i = tiisg; i < ne00/4; i += 32) {
-               for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+               for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
            }

            float all_sum = simd_sum(sumf);
            if (tiisg == 0) {
-               for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+               for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
            }
        }
    }
}
-[[host_name("kernel_mul_mv_f32_f32")]]
-kernel void kernel_mul_mv_f32_f32(
+template<typename T0, typename T04, typename T1, typename T14>
+kernel void kernel_mul_mv(
        device const char  * src0,
        device const char  * src1,
        device       float * dst,
@@ -1322,90 +1322,38 @@ kernel void kernel_mul_mv_f32_f32(
        constant     uint  & r3,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiisg[[thread_index_in_simdgroup]]) {
-   kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+   kernel_mul_mv_impl<T0, T04, T1, T14>(
+       src0,
+       src1,
+       dst,
+       ne00,
+       ne01,
+       ne02,
+       nb00,
+       nb01,
+       nb02,
+       ne10,
+       ne11,
+       ne12,
+       nb10,
+       nb11,
+       nb12,
+       ne0,
+       ne1,
+       r2,
+       r3,
+       tgpig,
+       tiisg);
}

-#define N_F16_F16 4
+typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;

-kernel void kernel_mul_mv_f16_f16(
-        [... argument list identical to the old kernel_mul_mv_f32_f32 above; removed ...]
-        uint tiisg[[thread_index_in_simdgroup]]) {
-    [... body removed: the same N_F16_F16-row dot-product loops, hard-coded for half inputs ...]
-}
+template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t kernel_mul_mv<float, float4, float, float4>;
+template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t kernel_mul_mv<half,  half4,  float, float4>;
+template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t kernel_mul_mv<half,  half4,  half,  half4>;

-void kernel_mul_mv_f16_f32_1row_impl(
+template<typename T, typename T4>
+kernel void kernel_mul_mv_1row(
        device const char  * src0,
        device const char  * src1,
        device       float * dst,
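The replacement above rests on a plain C++ idiom: write the kernel body once as a template, capture its signature with decltype, and emit one named symbol per type combination through explicit instantiation (in Metal the [[host_name(...)]] attribute supplies the exported name). A minimal hedged sketch of the same idiom in standard C++, with float/double standing in for float/half and all names invented:

```cpp
#include <cstdio>

// One templated "kernel", written once over the source element types,
// in the spirit of kernel_mul_mv<T0, T04, T1, T14>.
template <typename T0, typename T1>
void mul_mv(const T0 * x, const T1 * y, float * dst, int n) {
    float sum = 0.0f;
    for (int i = 0; i < n; ++i) {
        sum += (float) x[i] * (float) y[i];  // accumulate in float regardless of T0/T1
    }
    *dst = sum;
}

// Capture the shared signature, mirroring
//   typedef decltype(kernel_mul_mv<half, half4, half, half4>) mul_mv_t;
using mul_mv_f32_t = decltype(mul_mv<float, float>);

// Explicit instantiations force one concrete symbol per combination,
// playing the role of the template [[host_name("...")]] lines above.
template void mul_mv<float, float>(const float *, const float *, float *, int);
template void mul_mv<float, double>(const float *, const double *, float *, int);

int main() {
    const float x[3] = {1, 2, 3};
    const float y[3] = {4, 5, 6};
    float d = 0;
    mul_mv_f32_t * fn = mul_mv<float, float>;  // the typedef names the common signature
    fn(x, y, &d, 3);
    std::printf("%f\n", d);  // 32.000000
}
```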
@@ -1437,7 +1385,7 @@ void kernel_mul_mv_f16_f32_1row_impl(

    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;

-   device const half  * x = (device const half  *) (src0 + offset0);
+   device const T     * x = (device const T     *) (src0 + offset0);
    device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);

    float sumf = 0;
@@ -1450,153 +1398,29 @@ void kernel_mul_mv_f16_f32_1row_impl(
            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
        }
    } else {
-       device const half4  * x4 = (device const half4  *) x;
+       device const T4     * x4 = (device const T4     *) x;
        device const float4 * y4 = (device const float4 *) y;

        for (int i = tiisg; i < ne00/4; i += 32) {
-           for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+           for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
        }

        float all_sum = simd_sum(sumf);

        if (tiisg == 0) {
-           for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+           for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
        }
    }
}

-[[host_name("kernel_mul_mv_f16_f32_1row")]]
-kernel void kernel_mul_mv_f16_f32_1row(
-       [... argument list identical to the kernels above; removed ...]
-       uint tiisg[[thread_index_in_simdgroup]]) {
-   kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
-}
+typedef decltype(kernel_mul_mv_1row<half, half4>) mul_mv_1row_t;

-#define N_F16_F32 4
+template [[host_name("kernel_mul_mv_f16_f32_1row")]] kernel mul_mv_1row_t kernel_mul_mv_1row<half, half4>;

-void kernel_mul_mv_f16_f32_impl(
-       [... argument list and body removed: the same 4-rows-per-threadgroup dot-product loops as kernel_mul_mv_impl, hard-coded for half x float ...]
-}
-
-[[host_name("kernel_mul_mv_f16_f32")]]
-kernel void kernel_mul_mv_f16_f32(
-       [... argument list removed ...]
-       uint tiisg[[thread_index_in_simdgroup]]) {
-   kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
-}

 // Assumes row size (ne00) is a multiple of 4
-kernel void kernel_mul_mv_f16_f32_l4(
+template<typename T, typename T4>
+kernel void kernel_mul_mv_l4(
        device const char  * src0,
        device const char  * src1,
        device       float * dst,
@@ -1628,14 +1452,14 @@ kernel void kernel_mul_mv_f16_f32_l4(

    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;

-   device const half4 * x4 = (device const half4 *) (src0 + offset0);
+   device const T4    * x4 = (device const T4    *) (src0 + offset0);

    for (int r1 = 0; r1 < nrows; ++r1) {
        device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);

        float sumf = 0;
        for (int i = tiisg; i < ne00/4; i += 32) {
-           for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+           for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
        }

        float all_sum = simd_sum(sumf);
@@ -1645,6 +1469,10 @@ kernel void kernel_mul_mv_f16_f32_l4(
    }
}

+typedef decltype(kernel_mul_mv_l4<half, half4>) mul_mv_l4_t;
+
+template [[host_name("kernel_mul_mv_f16_f32_l4")]] kernel mul_mv_l4_t kernel_mul_mv_l4<half, half4>;
+
static float rope_yarn_ramp(const float low, const float high, const int i0) {
    const float y = (i0 / 2 - low) / max(0.001f, high - low);
    return 1.0f - min(1.0f, max(0.0f, y));
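The kernel_mul_mv_l4 hunks above keep the old f16_f32_l4 contract: the caller guarantees ne00 is a multiple of 4, so the kernel can walk the row purely in 4-wide vectors and never needs the scalar fix-up loop that the general kernel carries. A hedged standalone C++ sketch of that difference (float4 is emulated with std::array, names are illustrative):

```cpp
#include <array>
#include <cassert>
#include <cstdio>

using float4 = std::array<float, 4>;

// General path: vector body plus a scalar tail for lengths not divisible by 4,
// analogous to the "for (int i = 4*(ne00/4); ..." fix-up in kernel_mul_mv_impl.
float dot_general(const float * x, const float * y, int n) {
    float sum = 0.0f;
    int i = 0;
    for (; i + 4 <= n; i += 4) {
        for (int k = 0; k < 4; ++k) sum += x[i + k] * y[i + k];
    }
    for (; i < n; ++i) sum += x[i] * y[i];  // scalar tail
    return sum;
}

// "_l4" path: n % 4 == 0 is assumed, so the tail disappears entirely.
float dot_l4(const float4 * x4, const float4 * y4, int n) {
    assert(n % 4 == 0);
    float sum = 0.0f;
    for (int i = 0; i < n / 4; ++i) {
        for (int k = 0; k < 4; ++k) sum += x4[i][k] * y4[i][k];
    }
    return sum;
}

int main() {
    const float  x[6]  = {1, 2, 3, 4, 5, 6};
    const float  y[6]  = {1, 1, 1, 1, 1, 1};
    const float4 x4[1] = {{1, 2, 3, 4}};
    const float4 y4[1] = {{1, 1, 1, 1}};
    std::printf("%f %f\n", dot_general(x, y, 6), dot_l4(x4, y4, 4));  // 21 and 10
}
```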
@@ -2765,9 +2593,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<128>;
//template [[host_name("kernel_flash_attn_ext_vec_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_vec_f16<256>;

-kernel void kernel_cpy_f16_f16(
-       device const half * src0,
-       device       half * dst,
+template<typename T0, typename T1>
+kernel void kernel_cpy(
+       device const void * src0,
+       device       void * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
@@ -2798,138 +2627,20 @@ kernel void kernel_cpy_f16_f16(
    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);

-   device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+   device T1   * dst_data = (device T1   *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
-       device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
-       dst_data[i00] = src[0];
+       device const T0   * src = (device T0   *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+       dst_data[i00] = (T1) src[0];
    }
}

-kernel void kernel_cpy_f16_f32(
-       [... argument list and body removed: same indexing, half source to float destination ...]
-}
-
-kernel void kernel_cpy_f32_f16(
-       [... argument list and body removed: float source to half destination ...]
-}
-
-kernel void kernel_cpy_f32_f32(
-       [... argument list and body removed: float source to float destination ...]
-}
+typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
+
+template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
+template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
+template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy<half,  half>;
+template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy<half,  float>;

kernel void kernel_cpy_f32_q8_0(
        device const float * src0,
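Stripped of the Metal indexing, the new kernel_cpy<T0, T1> above is an element-wise convert-and-copy: read a T0, cast to T1, store. A hedged sketch of the same idea in ordinary C++ (strides and thread indexing omitted, float/double used because standard C++ has no half type):

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Convert-and-copy one row. The (T1) cast is what lets a single template stand
// in for cpy_f32_f32, cpy_f32_f16, cpy_f16_f16 and cpy_f16_f32.
template <typename T0, typename T1>
void cpy_row(const T0 * src, T1 * dst, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) {
        dst[i] = (T1) src[i];
    }
}

int main() {
    std::vector<float>  src = {1.5f, 2.5f, 3.5f};
    std::vector<double> dst(src.size());
    cpy_row(src.data(), dst.data(), src.size());
    std::printf("%f %f %f\n", dst[0], dst[1], dst[2]);
}
```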
@@ -5730,9 +5441,9 @@ void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4
}

template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
-kernel void kernel_get_rows(
+kernel void kernel_get_rows_q(
        device const  void * src0,
-       device const  char * src1,
+       device const  void * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
@@ -5745,27 +5456,24 @@ kernel void kernel_get_rows(
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint  tiitg[[thread_index_in_threadgroup]],
        uint3 tptg [[threads_per_threadgroup]]) {
-   //const int64_t i = tgpig;
-   //const int64_t r = ((device int32_t *) src1)[i];
-
    const int64_t i10 = tgpig.x;
    const int64_t i11 = tgpig.y;

-   const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+   const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];

    const int64_t i02 = i11;

    for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
        float4x4 temp;
-       dequantize_func(
-           ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
+       dequantize_func(((device const block_q *) ((const device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
        *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
    }
}

-kernel void kernel_get_rows_f32(
+template<typename T>
+kernel void kernel_get_rows_f(
        device const  void * src0,
-       device const  char * src1,
+       device const  void * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant  uint64_t & nb01,
@@ -5781,47 +5489,19 @@ kernel void kernel_get_rows_f32(
    const int64_t i10 = tgpig.x;
    const int64_t i11 = tgpig.y;

-   const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+   const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];

    const int64_t i02 = i11;

    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
-       ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
-       ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
-   }
-}
-
-kernel void kernel_get_rows_f16(
-       [... argument list and body removed: the same gather loop, hard-coded for a half source ...]
+       ((      device float *) ((      device char *) dst  + i11*nb2  + i10*nb1))[ind] =
+       ((const device T     *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
    }
}

kernel void kernel_get_rows_i32(
        device const    void * src0,
-       device const    char * src1,
+       device const    void * src1,
        device       int32_t * dst,
        constant     int64_t & ne00,
        constant    uint64_t & nb01,
@@ -5837,13 +5517,13 @@ kernel void kernel_get_rows_i32(
    const int64_t i10 = tgpig.x;
    const int64_t i11 = tgpig.y;

-   const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+   const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*nb11 + i10*nb10))[0];

    const int64_t i02 = i11;

    for (int ind = tiitg; ind < ne00; ind += tptg.x) {
-       ((device int32_t *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
-       ((device int32_t *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+       ((      device int32_t *) ((      device char *) dst  + i11*nb2  + i10*nb1))[ind] =
+       ((const device int32_t *) ((const device char *) src0 + i02*nb02 + r*nb01))[ind];
    }
}
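The kernel_get_rows_f / kernel_get_rows_q / kernel_get_rows_i32 hunks above all implement the same gather: src1 holds int32 row indices, and each output row is a copy (or dequantization) of the selected row of src0. A simplified CPU sketch of the float variant, with invented container types rather than ggml tensors:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Gather rows of `src` (rows of ne00 floats) according to `ids`,
// mirroring kernel_get_rows_f<float>: output row i = src row ids[i].
std::vector<float> get_rows_f32(const std::vector<float> & src,
                                int64_t ne00,
                                const std::vector<int32_t> & ids) {
    std::vector<float> dst(ids.size() * ne00);
    for (std::size_t i = 0; i < ids.size(); ++i) {
        const int64_t r = ids[i];                  // row index read from src1
        for (int64_t j = 0; j < ne00; ++j) {
            dst[i*ne00 + j] = src[r*ne00 + j];     // plain copy; the _q variant dequantizes here
        }
    }
    return dst;
}

int main() {
    std::vector<float>   src = {0, 0, 1, 1, 2, 2}; // 3 rows x 2 columns
    std::vector<int32_t> ids = {2, 0};
    auto dst = get_rows_f32(src, 2, ids);
    std::printf("%g %g %g %g\n", dst[0], dst[1], dst[2], dst[3]); // 2 2 0 0
}
```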
@@ -5860,28 +5540,28 @@ kernel void kernel_get_rows_i32(
#define SG_MAT_ROW 8

// each block_q contains 16*nl weights
-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
-void kernel_mul_mm_impl(device const uchar * src0,
+template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
+kernel void kernel_mul_mm(device const uchar * src0,
                        [... unchanged argument list: src1, dst, ne00, ne02, nb01, nb02, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, shared_memory, tgpig, tiitg, sgitg ...]
                        uint sgitg[[simdgroup_index_in_threadgroup]]) {

-   threadgroup half  * sa = (threadgroup half  *)(shared_memory);
+   threadgroup T     * sa = (threadgroup T     *)(shared_memory);
    threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);

    const uint r0 = tgpig.y;
@@ -5896,7 +5576,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
    short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
    short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;

-   simdgroup_half8x8  ma[4];
+   simdgroup_T8x8     ma[4];
    simdgroup_float8x8 mb[2];
    simdgroup_float8x8 c_res[8];
    for (int i = 0; i < 8; i++){
@@ -5919,7 +5599,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
    for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
        // load data and store to threadgroup memory
-       half4x4 temp_a;
+       T4x4    temp_a;
        dequantize_func(x, il, temp_a);
        threadgroup_barrier(mem_flags::mem_threadgroup);

@@ -5939,7 +5619,7 @@ void kernel_mul_mm_impl(device const uchar * src0,
    threadgroup_barrier(mem_flags::mem_threadgroup);

    // load matrices from threadgroup memory and conduct outer products
-   threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
+   threadgroup T     * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
    threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));

#pragma unroll(4)
@@ -6115,48 +5795,6 @@ void kernel_mul_mm_id_impl(
    }
}

-template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
-kernel void kernel_mul_mm(device const uchar * src0,
-       [... argument list identical to kernel_mul_mm_impl; removed ...]
-       uint sgitg[[simdgroup_index_in_threadgroup]]) {
-   kernel_mul_mm_impl<block_q, nl, dequantize_func>(
-       [... the 19 arguments forwarded one per line: src0, src1, dst, ne00, ne02, nb01, nb02, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, shared_memory, tgpig, tiitg, sgitg ...]
-       sgitg);
-}
-
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
kernel void kernel_mul_mm_id(
        device const uchar * src0s,
@@ -6237,69 +5875,60 @@ kernel void kernel_mul_mm_id(
//
// get rows
//

-typedef void (get_rows_t)(
-       device const void * src0,
-       device const char * src1,
-       device      float * dst,
-       constant  int64_t & ne00,
-       constant uint64_t & nb01,
-       constant uint64_t & nb02,
-       constant  int64_t & ne10,
-       constant uint64_t & nb10,
-       constant uint64_t & nb11,
-       constant uint64_t & nb1,
-       constant uint64_t & nb2,
-       uint3, uint, uint3);
+typedef decltype(kernel_get_rows_f<float>) get_rows_f_t;

-//template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
-//template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,  1, dequantize_f16>;
-template [[host_name("kernel_get_rows_q4_0")]]   kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
-template [[host_name("kernel_get_rows_q4_1")]]   kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
-[... and likewise the old kernel_get_rows<...> instantiations for q5_0, q5_1, q8_0, q2_K through q6_K, iq2_xxs, iq2_xs, iq3_xxs, iq3_s, iq2_s, iq1_s, iq1_m, iq4_nl and iq4_xs; all removed ...]
+template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f<float>;
+template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f<half>;
+
+typedef decltype(kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>) get_rows_q_t;
+
+template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_q_t kernel_get_rows_q<block_q4_1, 2, dequantize_q4_1>;
+[... and likewise new kernel_get_rows_q instantiations for q5_0, q5_1, q8_0, q2_K through q6_K, iq2_xxs, iq2_xs, iq3_xxs, iq3_s, iq2_s, iq1_s, iq1_m, iq4_nl (nl = 2) and iq4_xs ...]

//
// matrix-matrix multiplication
//

-typedef decltype(kernel_mul_mm<float4x4, 1, dequantize_f32>) mat_mm_t;
+typedef decltype(kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1, dequantize_f32>) mat_mm_t;

-template [[host_name("kernel_mul_mm_f32_f32")]]  kernel mat_mm_t kernel_mul_mm<float4x4,   1, dequantize_f32>;
-template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1, dequantize_f16>;
-template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
-[... and likewise the old kernel_mul_mm<...> instantiations for the remaining quant types through iq4_xs ...]
+template [[host_name("kernel_mul_mm_f32_f32")]]  kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4,   1, dequantize_f32>;
+template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4,    1, dequantize_f16>;
+template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0>;
+[... and likewise for q4_1, q5_0, q5_1, q8_0, q2_K through q6_K, iq2_xxs, iq2_xs, iq3_xxs, iq3_s, iq2_s, iq1_s, iq1_m, iq4_nl and iq4_xs, each gaining the new <half, half4x4, simdgroup_half8x8, ...> prefix ...]

//
// indirect matrix-matrix multiplication
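Both instantiation lists above hinge on the same mechanism: the dequantizer is passed as a non-type template parameter (a function that expands one quantized block into a small float tile), so kernel_mul_mm and kernel_get_rows_q never name a concrete quant format. A hedged C++ sketch of that mechanism with an invented toy block format, not a real ggml type:

```cpp
#include <array>
#include <cstdint>
#include <cstdio>

// Toy quantized block: one scale and four 8-bit values (illustrative only).
struct block_toy {
    float  d;
    int8_t qs[4];
};

using tile4 = std::array<float, 4>;

// The "dequantize_func" shape: expand one block into a float tile.
inline void dequantize_toy(const block_toy * b, tile4 & out) {
    for (int i = 0; i < 4; ++i) out[i] = b->d * b->qs[i];
}

// The algorithm is written once; the format comes in as <block type, dequantizer>.
template <typename block_q, void (*dequantize)(const block_q *, tile4 &)>
float sum_all_q(const block_q * blocks, int nblocks) {
    float sum = 0.0f;
    for (int i = 0; i < nblocks; ++i) {
        tile4 t;
        dequantize(blocks + i, t);
        for (float v : t) sum += v;
    }
    return sum;
}

int main() {
    block_toy blocks[2] = {{0.5f, {1, 2, 3, 4}}, {2.0f, {1, 1, 1, 1}}};
    std::printf("%f\n", sum_all_q<block_toy, dequantize_toy>(blocks, 2));  // 5 + 8 = 13
}
```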
@@ -6436,7 +6065,7 @@ void mmv_fn(
    impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
}

-typedef decltype(mmv_fn<kernel_mul_mv_f32_f32_impl>) mul_mv_impl_fn_t;
+typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4>>) mul_mv_impl_fn_t;

template<mul_mv_impl_fn_t impl_fn>
kernel void kernel_mul_mv_id(
@@ -6514,20 +6143,20 @@ kernel void kernel_mul_mv_id(
        sgitg);
}

-typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>) kernel_mul_mv_id_t;
+typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>) kernel_mul_mv_id_t;

-template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>;
-template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f16_f32_impl>>;
+template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<float, float4, float, float4>>>;
+template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_impl<half, half4, float, float4>>>;
template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_q8_0_f32_impl>>;
template [[host_name("kernel_mul_mv_id_q4_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>>>;
[... unchanged kernel_mul_mv_id instantiations for q4_1, q5_0, q5_1, q2_K through q6_K, iq1_s, iq1_m and iq2_xxs ...]
@@ -659,7 +659,7 @@ static inline __m128i packNibbles( __m256i bytes ) {
 #endif //__loongarch_asx

 // reference implementation for deterministic creation of model files
-void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
+void quantize_row_q4_0_ref(const float * restrict x, block_q4_0 * restrict y, int64_t k) {
     static const int qk = QK4_0;

     assert(k % qk == 0);
@@ -697,11 +697,11 @@ void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict
 }

 void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) {
-    quantize_row_q4_0_reference(x, y, k);
+    quantize_row_q4_0_ref(x, y, k);
 }


-void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict y, int64_t k) {
+void quantize_row_q4_1_ref(const float * restrict x, block_q4_1 * restrict y, int64_t k) {
     const int qk = QK4_1;

     assert(k % qk == 0);
@@ -739,10 +739,10 @@ void quantize_row_q4_1_reference(const float * restrict x, block_q4_1 * restrict
 }

 void quantize_row_q4_1(const float * restrict x, void * restrict y, int64_t k) {
-    quantize_row_q4_1_reference(x, y, k);
+    quantize_row_q4_1_ref(x, y, k);
 }

-void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict y, int64_t k) {
+void quantize_row_q5_0_ref(const float * restrict x, block_q5_0 * restrict y, int64_t k) {
     static const int qk = QK5_0;

     assert(k % qk == 0);
@@ -787,10 +787,10 @@ void quantize_row_q5_0_reference(const float * restrict x, block_q5_0 * restrict
 }

 void quantize_row_q5_0(const float * restrict x, void * restrict y, int64_t k) {
-    quantize_row_q5_0_reference(x, y, k);
+    quantize_row_q5_0_ref(x, y, k);
 }

-void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict y, int64_t k) {
+void quantize_row_q5_1_ref(const float * restrict x, block_q5_1 * restrict y, int64_t k) {
     const int qk = QK5_1;

     assert(k % qk == 0);
@@ -835,11 +835,11 @@ void quantize_row_q5_1_reference(const float * restrict x, block_q5_1 * restrict
 }

 void quantize_row_q5_1(const float * restrict x, void * restrict y, int64_t k) {
-    quantize_row_q5_1_reference(x, y, k);
+    quantize_row_q5_1_ref(x, y, k);
 }

 // reference implementation for deterministic creation of model files
-void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * restrict y, int64_t k) {
+void quantize_row_q8_0_ref(const float * restrict x, block_q8_0 * restrict y, int64_t k) {
     assert(k % QK8_0 == 0);
     const int nb = k / QK8_0;

@@ -1145,12 +1145,12 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
 #else
     GGML_UNUSED(nb);
     // scalar
-    quantize_row_q8_0_reference(x, y, k);
+    quantize_row_q8_0_ref(x, y, k);
 #endif
 }

 // reference implementation for deterministic creation of model files
-void quantize_row_q8_1_reference(const float * restrict x, block_q8_1 * restrict y, int64_t k) {
+void quantize_row_q8_1_ref(const float * restrict x, block_q8_1 * restrict y, int64_t k) {
     assert(QK8_1 == 32);
     assert(k % QK8_1 == 0);
     const int nb = k / QK8_1;
@@ -1509,7 +1509,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
 #else
     GGML_UNUSED(nb);
     // scalar
-    quantize_row_q8_1_reference(x, y, k);
+    quantize_row_q8_1_ref(x, y, k);
 #endif
 }
@@ -1900,7 +1900,7 @@ static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t *

 //========================- 2-bit (de)-quantization

-void quantize_row_q2_K_reference(const float * restrict x, block_q2_K * restrict y, int64_t k) {
+void quantize_row_q2_K_ref(const float * restrict x, block_q2_K * restrict y, int64_t k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;

@@ -2003,7 +2003,7 @@ void dequantize_row_q2_K(const block_q2_K * restrict x, float * restrict y, int6
 }

 void quantize_row_q2_K(const float * restrict x, void * restrict vy, int64_t k) {
-    quantize_row_q2_K_reference(x, vy, k);
+    quantize_row_q2_K_ref(x, vy, k);
 }

 static float make_qkx3_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
@@ -2227,7 +2227,7 @@ static void quantize_row_q2_K_impl(const float * restrict x, block_q2_K * restri
 size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     size_t row_size = ggml_row_size(GGML_TYPE_Q2_K, n_per_row);
     if (!quant_weights) {
-        quantize_row_q2_K_reference(src, dst, (int64_t)nrow*n_per_row);
+        quantize_row_q2_K_ref(src, dst, (int64_t)nrow*n_per_row);
     }
     else {
         char * qrow = (char *)dst;
@@ -2242,7 +2242,7 @@ size_t quantize_q2_K(const float * restrict src, void * restrict dst, int64_t nr

 //========================= 3-bit (de)-quantization

-void quantize_row_q3_K_reference(const float * restrict x, block_q3_K * restrict y, int64_t k) {
+void quantize_row_q3_K_ref(const float * restrict x, block_q3_K * restrict y, int64_t k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;

@@ -2369,7 +2369,7 @@ void dequantize_row_q3_K(const block_q3_K * restrict x, float * restrict y, int6
 }

 void quantize_row_q3_K(const float * restrict x, void * restrict vy, int64_t k) {
-    quantize_row_q3_K_reference(x, vy, k);
+    quantize_row_q3_K_ref(x, vy, k);
 }

 static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restrict y, int64_t n_per_row, const float * restrict quant_weights) {
@@ -2459,7 +2459,7 @@ static void quantize_row_q3_K_impl(const float * restrict x, block_q3_K * restri
 size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     size_t row_size = ggml_row_size(GGML_TYPE_Q3_K, n_per_row);
     if (!quant_weights) {
-        quantize_row_q3_K_reference(src, dst, (int64_t)nrow*n_per_row);
+        quantize_row_q3_K_ref(src, dst, (int64_t)nrow*n_per_row);
     }
     else {
         char * qrow = (char *)dst;
@@ -2474,7 +2474,7 @@ size_t quantize_q3_K(const float * restrict src, void * restrict dst, int64_t nr

 // ====================== 4-bit (de)-quantization

-void quantize_row_q4_K_reference(const float * restrict x, block_q4_K * restrict y, int64_t k) {
+void quantize_row_q4_K_ref(const float * restrict x, block_q4_K * restrict y, int64_t k) {
     assert(k % QK_K == 0);
     const int nb = k / QK_K;

@@ -2573,7 +2573,7 @@ void dequantize_row_q4_K(const block_q4_K * restrict x, float * restrict y, int6
 void quantize_row_q4_K(const float * restrict x, void * restrict vy, int64_t k) {
     assert(k % QK_K == 0);
     block_q4_K * restrict y = vy;
-    quantize_row_q4_K_reference(x, y, k);
+    quantize_row_q4_K_ref(x, y, k);
 }

 static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restrict y, int64_t n_per_row, const float * quant_weights) {
@@ -2652,7 +2652,7 @@ static void quantize_row_q4_K_impl(const float * restrict x, block_q4_K * restri
 size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     size_t row_size = ggml_row_size(GGML_TYPE_Q4_K, n_per_row);
     if (!quant_weights) {
-        quantize_row_q4_K_reference(src, dst, (int64_t)nrow*n_per_row);
+        quantize_row_q4_K_ref(src, dst, (int64_t)nrow*n_per_row);
     }
     else {
         char * qrow = (char *)dst;
@@ -2667,7 +2667,7 @@ size_t quantize_q4_K(const float * restrict src, void * restrict dst, int64_t nr

 // ====================== 5-bit (de)-quantization

-void quantize_row_q5_K_reference(const float * restrict x, block_q5_K * restrict y, int64_t k) {
+void quantize_row_q5_K_ref(const float * restrict x, block_q5_K * restrict y, int64_t k) {
     assert(k % QK_K == 0);
     const int64_t nb = k / QK_K;

@@ -2784,7 +2784,7 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int6
 void quantize_row_q5_K(const float * restrict x, void * restrict vy, int64_t k) {
     assert(k % QK_K == 0);
     block_q5_K * restrict y = vy;
-    quantize_row_q5_K_reference(x, y, k);
+    quantize_row_q5_K_ref(x, y, k);
 }

 static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restrict y, int64_t n_per_row, const float * quant_weights) {
@@ -2883,7 +2883,7 @@ static void quantize_row_q5_K_impl(const float * restrict x, block_q5_K * restri
 size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     size_t row_size = ggml_row_size(GGML_TYPE_Q5_K, n_per_row);
     if (!quant_weights) {
-        quantize_row_q5_K_reference(src, dst, (int64_t)nrow*n_per_row);
+        quantize_row_q5_K_ref(src, dst, (int64_t)nrow*n_per_row);
     }
     else {
         char * qrow = (char *)dst;
@@ -2898,7 +2898,7 @@ size_t quantize_q5_K(const float * restrict src, void * restrict dst, int64_t nr

 // ====================== 6-bit (de)-quantization

-void quantize_row_q6_K_reference(const float * restrict x, block_q6_K * restrict y, int64_t k) {
+void quantize_row_q6_K_ref(const float * restrict x, block_q6_K * restrict y, int64_t k) {
     assert(k % QK_K == 0);
     const int64_t nb = k / QK_K;

@@ -3002,7 +3002,7 @@ void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int6
 void quantize_row_q6_K(const float * restrict x, void * restrict vy, int64_t k) {
     assert(k % QK_K == 0);
     block_q6_K * restrict y = vy;
-    quantize_row_q6_K_reference(x, y, k);
+    quantize_row_q6_K_ref(x, y, k);
 }

 static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restrict y, int64_t n_per_row, const float * quant_weights) {
@@ -3092,7 +3092,7 @@ static void quantize_row_q6_K_impl(const float * restrict x, block_q6_K * restri
 size_t quantize_q6_K(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     size_t row_size = ggml_row_size(GGML_TYPE_Q6_K, n_per_row);
     if (!quant_weights) {
-        quantize_row_q6_K_reference(src, dst, (int64_t)nrow*n_per_row);
+        quantize_row_q6_K_ref(src, dst, (int64_t)nrow*n_per_row);
     }
     else {
         char * qrow = (char *)dst;
@@ -3109,7 +3109,7 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri
     static_assert(QK4_0 == 32, "QK4_0 must be 32");

     if (!quant_weights) {
-        quantize_row_q4_0_reference(x, y, n_per_row);
+        quantize_row_q4_0_ref(x, y, n_per_row);
         return;
     }

@@ -3135,7 +3135,7 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri

 size_t quantize_q4_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     if (!quant_weights) {
-        quantize_row_q4_0_reference(src, dst, (int64_t)nrow*n_per_row);
+        quantize_row_q4_0_ref(src, dst, (int64_t)nrow*n_per_row);
         return nrow * ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
     }
     size_t row_size = ggml_row_size(GGML_TYPE_Q4_0, n_per_row);
@@ -3152,7 +3152,7 @@ static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restri
     static_assert(QK4_1 == 32, "QK4_1 must be 32");

     if (!quant_weights) {
-        quantize_row_q4_1_reference(x, y, n_per_row);
+        quantize_row_q4_1_ref(x, y, n_per_row);
         return;
     }

@@ -3180,7 +3180,7 @@ static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restri

 size_t quantize_q4_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     if (!quant_weights) {
-        quantize_row_q4_1_reference(src, dst, (int64_t)nrow*n_per_row);
+        quantize_row_q4_1_ref(src, dst, (int64_t)nrow*n_per_row);
         return nrow * ggml_row_size(GGML_TYPE_Q4_1, n_per_row);
     }
     size_t row_size = ggml_row_size(GGML_TYPE_Q4_1, n_per_row);
@@ -3197,7 +3197,7 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri
     static_assert(QK5_0 == 32, "QK5_0 must be 32");

     if (!quant_weights) {
-        quantize_row_q5_0_reference(x, y, n_per_row);
+        quantize_row_q5_0_ref(x, y, n_per_row);
         return;
     }

@@ -3234,7 +3234,7 @@ static void quantize_row_q5_0_impl(const float * restrict x, block_q5_0 * restri

 size_t quantize_q5_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     if (!quant_weights) {
-        quantize_row_q5_0_reference(src, dst, (int64_t)nrow*n_per_row);
+        quantize_row_q5_0_ref(src, dst, (int64_t)nrow*n_per_row);
         return nrow * ggml_row_size(GGML_TYPE_Q5_0, n_per_row);
     }
     size_t row_size = ggml_row_size(GGML_TYPE_Q5_0, n_per_row);
@@ -3251,7 +3251,7 @@ static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restri
     static_assert(QK5_1 == 32, "QK5_1 must be 32");

     if (!quant_weights) {
-        quantize_row_q5_1_reference(x, y, n_per_row);
+        quantize_row_q5_1_ref(x, y, n_per_row);
         return;
     }

@@ -3287,7 +3287,7 @@ static void quantize_row_q5_1_impl(const float * restrict x, block_q5_1 * restri

 size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     if (!quant_weights) {
-        quantize_row_q5_1_reference(src, dst, (int64_t)nrow*n_per_row);
+        quantize_row_q5_1_ref(src, dst, (int64_t)nrow*n_per_row);
         return nrow * ggml_row_size(GGML_TYPE_Q5_1, n_per_row);
     }
     size_t row_size = ggml_row_size(GGML_TYPE_Q5_1, n_per_row);
@@ -3303,7 +3303,7 @@ size_t quantize_q5_1(const float * restrict src, void * restrict dst, int64_t nr
 size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) {
     (void)quant_weights; // not used
     const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0, n_per_row);
-    quantize_row_q8_0_reference(src, dst, (int64_t)nrow*n_per_row);
+    quantize_row_q8_0_ref(src, dst, (int64_t)nrow*n_per_row);
     return nrow * row_size;
 }

@@ -3591,7 +3591,7 @@ void dequantize_row_iq4_xs(const block_iq4_xs * restrict x, float * restrict y,

 //===================================== Q8_K ==============================================

-void quantize_row_q8_K_reference(const float * restrict x, block_q8_K * restrict y, int64_t k) {
+void quantize_row_q8_K_ref(const float * restrict x, block_q8_K * restrict y, int64_t k) {
     assert(k % QK_K == 0);
     const int64_t nb = k / QK_K;

@@ -3642,7 +3642,7 @@ void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int6
 }

 void quantize_row_q8_K(const float * restrict x, void * restrict y, int64_t k) {
-    quantize_row_q8_K_reference(x, y, k);
+    quantize_row_q8_K_ref(x, y, k);
 }

 //===================================== Dot ptoducts =================================
@@ -13531,10 +13531,10 @@ size_t quantize_iq3_xxs(const float * restrict src, void * restrict dst, int64_t
 void quantize_row_iq3_xxs(const float * restrict x, void * restrict vy, int64_t k) {
     assert(k % QK_K == 0);
     block_iq3_xxs * restrict y = vy;
-    quantize_row_iq3_xxs_reference(x, y, k);
+    quantize_row_iq3_xxs_ref(x, y, k);
 }

-void quantize_row_iq3_xxs_reference(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) {
+void quantize_row_iq3_xxs_ref(const float * restrict x, block_iq3_xxs * restrict y, int64_t k) {
     assert(k % QK_K == 0);
     quantize_row_iq3_xxs_impl(256, x, y, k, NULL);
 }
@@ -13747,10 +13747,10 @@ size_t quantize_iq3_s(const float * restrict src, void * restrict dst, int64_t n
 void quantize_row_iq3_s(const float * restrict x, void * restrict vy, int64_t k) {
     assert(k % QK_K == 0);
     block_iq3_s * restrict y = vy;
-    quantize_row_iq3_s_reference(x, y, k);
+    quantize_row_iq3_s_ref(x, y, k);
 }

-void quantize_row_iq3_s_reference(const float * restrict x, block_iq3_s * restrict y, int64_t k) {
+void quantize_row_iq3_s_ref(const float * restrict x, block_iq3_s * restrict y, int64_t k) {
     assert(k % QK_K == 0);
     quantize_iq3_s(x, y, 1, k, NULL);
 }
@@ -14488,7 +14488,7 @@ void quantize_row_iq4_nl(const float * restrict x, void * restrict vy, int64_t k
     }
 }

-void quantize_row_iq4_nl_reference(const float * restrict x, block_iq4_nl * restrict y, int64_t k) {
+void quantize_row_iq4_nl_ref(const float * restrict x, block_iq4_nl * restrict y, int64_t k) {
     assert(k % QK4_NL == 0);
     quantize_row_iq4_nl(x, y, k);
 }
@@ -14516,10 +14516,10 @@ size_t quantize_iq4_xs(const float * restrict src, void * restrict dst, int64_t
 void quantize_row_iq4_xs(const float * restrict x, void * restrict vy, int64_t k) {
     assert(k % QK_K == 0);
     block_iq4_xs * restrict y = vy;
-    quantize_row_iq4_xs_reference(x, y, k);
+    quantize_row_iq4_xs_ref(x, y, k);
 }

-void quantize_row_iq4_xs_reference(const float * restrict x, block_iq4_xs * restrict y, int64_t k) {
+void quantize_row_iq4_xs_ref(const float * restrict x, block_iq4_xs * restrict y, int64_t k) {
     assert(k % QK_K == 0);
     quantize_iq4_xs(x, y, 1, k, NULL);
 }
@@ -14706,7 +14706,7 @@ size_t quantize_iq2_s(const float * restrict src, void * restrict dst, int64_t n
     return nrow * nblock * sizeof(block_iq2_s);
 }

-void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restrict y, int64_t k) {
+void quantize_row_iq2_s_ref(const float * restrict x, block_iq2_s * restrict y, int64_t k) {
     assert(k % QK_K == 0);
     quantize_iq2_s(x, y, 1, k, NULL);
 }
@@ -14714,7 +14714,7 @@ void quantize_row_iq2_s_reference(const float * restrict x, block_iq2_s * restri
 void quantize_row_iq2_s(const float * restrict x, void * restrict vy, int64_t k) {
     assert(k % QK_K == 0);
     block_iq2_s * restrict y = vy;
-    quantize_row_iq2_s_reference(x, y, k);
+    quantize_row_iq2_s_ref(x, y, k);
 }

 static bool validate_float(float f, size_t i) {
@@ -12,25 +12,25 @@ extern "C" {
 #endif

 // Quantization
-void quantize_row_q4_0_reference(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k);
-void quantize_row_q4_1_reference(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k);
-void quantize_row_q5_0_reference(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k);
-void quantize_row_q5_1_reference(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k);
-void quantize_row_q8_0_reference(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k);
-void quantize_row_q8_1_reference(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q4_1_ref(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_1_ref(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k);

-void quantize_row_q2_K_reference(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
-void quantize_row_q3_K_reference(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
-void quantize_row_q4_K_reference(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);
-void quantize_row_q5_K_reference(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k);
-void quantize_row_q6_K_reference(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k);
-void quantize_row_q8_K_reference(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k);
+void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k);
+void quantize_row_q3_K_ref(const float * GGML_RESTRICT x, block_q3_K * GGML_RESTRICT y, int64_t k);
+void quantize_row_q4_K_ref(const float * GGML_RESTRICT x, block_q4_K * GGML_RESTRICT y, int64_t k);
+void quantize_row_q5_K_ref(const float * GGML_RESTRICT x, block_q5_K * GGML_RESTRICT y, int64_t k);
+void quantize_row_q6_K_ref(const float * GGML_RESTRICT x, block_q6_K * GGML_RESTRICT y, int64_t k);
+void quantize_row_q8_K_ref(const float * GGML_RESTRICT x, block_q8_K * GGML_RESTRICT y, int64_t k);

-void quantize_row_iq3_xxs_reference(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k);
-void quantize_row_iq4_nl_reference (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k);
-void quantize_row_iq4_xs_reference (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k);
-void quantize_row_iq3_s_reference  (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k);
-void quantize_row_iq2_s_reference  (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq3_xxs_ref(const float * GGML_RESTRICT x, block_iq3_xxs * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq4_nl_ref (const float * GGML_RESTRICT x, block_iq4_nl * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq4_xs_ref (const float * GGML_RESTRICT x, block_iq4_xs * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq3_s_ref  (const float * GGML_RESTRICT x, block_iq3_s * GGML_RESTRICT y, int64_t k);
+void quantize_row_iq2_s_ref  (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k);

 void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
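For orientation, a minimal sketch of how the renamed reference quantizers are called from client code. This is illustrative only: it assumes a normal ggml build where QK4_0 == 32 and the block_q4_0 type is visible through ggml-quants.h; the example data and helper name are made up.

    #include "ggml-quants.h"

    // quantize one row of 64 floats with the deterministic reference path
    static void example_quantize_row(void) {
        float src[64];
        for (int i = 0; i < 64; i++) {
            src[i] = i / 64.0f;                   // any data, length % QK4_0 == 0
        }

        block_q4_0 dst[64 / QK4_0];               // one block per QK4_0 input values
        quantize_row_q4_0_ref(src, dst, 64);      // was quantize_row_q4_0_reference(...)

        // the dispatching front-end keeps its old name and produces the same layout:
        // quantize_row_q4_0(src, dst, 64);
    }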
@@ -3768,37 +3768,13 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
             stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
         SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));

-        const ggml_tensor_extra_gpu *src0_extra =
-            (const ggml_tensor_extra_gpu *)src0->extra;
-        const ggml_tensor_extra_gpu *src1_extra =
-            (const ggml_tensor_extra_gpu *)src1->extra;
-        const ggml_tensor_extra_gpu *dst_extra =
-            (const ggml_tensor_extra_gpu *)dst->extra;
-
-        ggml_tensor_extra_gpu src0_row_extra;
-        ggml_tensor_extra_gpu src1_row_extra;
-        ggml_tensor_extra_gpu dst_row_extra;
-
         ggml_tensor src0_row = *src0;
         ggml_tensor src1_row = *src1;
         ggml_tensor dst_row = *dst;

-        src1_row.backend = GGML_BACKEND_TYPE_GPU;
-        dst_row.backend = GGML_BACKEND_TYPE_GPU;
-
-        src0_row.extra = &src0_row_extra;
-        src1_row.extra = &src1_row_extra;
-        dst_row.extra = &dst_row_extra;
-
-        char *src0_original = src1->backend == GGML_BACKEND_TYPE_CPU
-                                  ? (char *)src0->data
-                                  : (char *)src0_extra->data_device[ctx.device];
-        char *src1_original = src1->backend == GGML_BACKEND_TYPE_CPU
-                                  ? (char *)src1->data
-                                  : (char *)src1_extra->data_device[ctx.device];
-        char *dst_original = dst->backend == GGML_BACKEND_TYPE_CPU
-                                 ? (char *)dst->data
-                                 : (char *)dst_extra->data_device[ctx.device];
+        char *src0_original = (char *)src0->data;
+        char *src1_original = (char *)src1->data;
+        char *dst_original  = (char *)dst->data;

         src0_row.ne[2] = 1;
         src0_row.ne[3] = 1;
@@ -3827,12 +3803,9 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
                 const int64_t i1 = id;
                 const int64_t i2 = i12;

-                src0_row_extra.data_device[ctx.device] =
-                    src0_original + i02*nb02;
-                src1_row_extra.data_device[ctx.device] =
-                    src1_original + + i11*nb11 + i12*nb12;
-                dst_row_extra.data_device[ctx.device] =
-                    dst_original + i1*nb1 + i2*nb2;
+                src0_row.data = src0_original + i02*nb02;
+                src1_row.data = src1_original + + i11*nb11 + i12*nb12;
+                dst_row.data = dst_original + i1*nb1 + i2*nb2;

                 ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
             }
@@ -3841,8 +3814,8 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
         ggml_sycl_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
         ggml_sycl_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));

-        src1_row_extra.data_device[ctx.device] = src1_contiguous.get();
-        dst_row_extra.data_device[ctx.device] = dst_contiguous.get();
+        src1_row.data = src1_contiguous.get();
+        dst_row.data = dst_contiguous.get();

         for (int64_t i02 = 0; i02 < n_as; i02++) {
             int64_t num_src1_rows = 0;
@@ -3898,7 +3871,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
             });
         }

-        src0_row_extra.data_device[ctx.device] = src0_original + i02*nb02;
+        src0_row.data = src0_original + i02*nb02;

         GGML_ASSERT(nb11 == sizeof(float)*ne10);
         GGML_ASSERT(nb1 == sizeof(float)*ne0);
@@ -5221,6 +5194,10 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
                         return false;
                     }
                 }
+                ggml_type src0_type = op->src[0]->type;
+                if (src0_type == GGML_TYPE_BF16) {
+                    return false;
+                }
                 return true;
             } break;
         case GGML_OP_GET_ROWS:
ggml/src/ggml.c (111 changed lines)
@@ -600,7 +600,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = false,
         .to_float = (ggml_to_float_t) ggml_fp16_to_fp32_row,
         .from_float = (ggml_from_float_t) ggml_fp32_to_fp16_row,
-        .from_float_reference = (ggml_from_float_t) ggml_fp32_to_fp16_row,
+        .from_float_ref = (ggml_from_float_t) ggml_fp32_to_fp16_row,
         .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16,
         .vec_dot_type = GGML_TYPE_F16,
         .nrows = 1,
@@ -612,7 +612,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_q4_0,
         .from_float = quantize_row_q4_0,
-        .from_float_reference = (ggml_from_float_t) quantize_row_q4_0_reference,
+        .from_float_ref = (ggml_from_float_t) quantize_row_q4_0_ref,
         .vec_dot = ggml_vec_dot_q4_0_q8_0,
         .vec_dot_type = GGML_TYPE_Q8_0,
 #if defined (__ARM_FEATURE_MATMUL_INT8)
@@ -628,7 +628,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_q4_1,
         .from_float = quantize_row_q4_1,
-        .from_float_reference = (ggml_from_float_t) quantize_row_q4_1_reference,
+        .from_float_ref = (ggml_from_float_t) quantize_row_q4_1_ref,
         .vec_dot = ggml_vec_dot_q4_1_q8_1,
         .vec_dot_type = GGML_TYPE_Q8_1,
 #if defined (__ARM_FEATURE_MATMUL_INT8)
@@ -644,7 +644,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = false,
         .to_float = NULL,
         .from_float = NULL,
-        .from_float_reference = NULL,
+        .from_float_ref = NULL,
         .vec_dot = NULL,
         .vec_dot_type = GGML_TYPE_COUNT,
         .nrows = 1,
@@ -656,7 +656,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = false,
         .to_float = NULL,
         .from_float = NULL,
-        .from_float_reference = NULL,
+        .from_float_ref = NULL,
         .vec_dot = NULL,
         .vec_dot_type = GGML_TYPE_COUNT,
         .nrows = 1,
@@ -668,7 +668,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_q5_0,
         .from_float = quantize_row_q5_0,
-        .from_float_reference = (ggml_from_float_t) quantize_row_q5_0_reference,
+        .from_float_ref = (ggml_from_float_t) quantize_row_q5_0_ref,
         .vec_dot = ggml_vec_dot_q5_0_q8_0,
         .vec_dot_type = GGML_TYPE_Q8_0,
         .nrows = 1,
@@ -680,7 +680,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_q5_1,
         .from_float = quantize_row_q5_1,
-        .from_float_reference = (ggml_from_float_t) quantize_row_q5_1_reference,
+        .from_float_ref = (ggml_from_float_t) quantize_row_q5_1_ref,
         .vec_dot = ggml_vec_dot_q5_1_q8_1,
         .vec_dot_type = GGML_TYPE_Q8_1,
         .nrows = 1,
@@ -692,7 +692,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_q8_0,
         .from_float = quantize_row_q8_0,
-        .from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference,
+        .from_float_ref = (ggml_from_float_t) quantize_row_q8_0_ref,
+        .from_float_to_mat = quantize_mat_q8_0,
         .vec_dot = ggml_vec_dot_q8_0_q8_0,
         .vec_dot_type = GGML_TYPE_Q8_0,
 #if defined (__ARM_FEATURE_MATMUL_INT8)
@@ -700,7 +701,6 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
 #else
         .nrows = 1,
 #endif
-        .from_float_to_mat = quantize_mat_q8_0,
     },
     [GGML_TYPE_Q8_1] = {
         .type_name = "q8_1",
@@ -708,7 +708,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .type_size = sizeof(block_q8_1),
         .is_quantized = true,
         .from_float = quantize_row_q8_1,
-        .from_float_reference = (ggml_from_float_t) quantize_row_q8_1_reference,
+        .from_float_ref = (ggml_from_float_t) quantize_row_q8_1_ref,
         .vec_dot_type = GGML_TYPE_Q8_1,
         .nrows = 1,
     },
@@ -719,7 +719,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_q2_K,
         .from_float = quantize_row_q2_K,
-        .from_float_reference = (ggml_from_float_t) quantize_row_q2_K_reference,
+        .from_float_ref = (ggml_from_float_t) quantize_row_q2_K_ref,
         .vec_dot = ggml_vec_dot_q2_K_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
@@ -731,7 +731,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_q3_K,
         .from_float = quantize_row_q3_K,
-        .from_float_reference = (ggml_from_float_t) quantize_row_q3_K_reference,
+        .from_float_ref = (ggml_from_float_t) quantize_row_q3_K_ref,
         .vec_dot = ggml_vec_dot_q3_K_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
@@ -743,7 +743,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_q4_K,
         .from_float = quantize_row_q4_K,
-        .from_float_reference = (ggml_from_float_t) quantize_row_q4_K_reference,
+        .from_float_ref = (ggml_from_float_t) quantize_row_q4_K_ref,
         .vec_dot = ggml_vec_dot_q4_K_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
@@ -755,7 +755,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_q5_K,
         .from_float = quantize_row_q5_K,
-        .from_float_reference = (ggml_from_float_t) quantize_row_q5_K_reference,
+        .from_float_ref = (ggml_from_float_t) quantize_row_q5_K_ref,
         .vec_dot = ggml_vec_dot_q5_K_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
@@ -767,7 +767,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_q6_K,
         .from_float = quantize_row_q6_K,
-        .from_float_reference = (ggml_from_float_t) quantize_row_q6_K_reference,
+        .from_float_ref = (ggml_from_float_t) quantize_row_q6_K_ref,
         .vec_dot = ggml_vec_dot_q6_K_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
@@ -779,7 +779,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_iq2_xxs,
         .from_float = NULL,
-        .from_float_reference = NULL,
+        .from_float_ref = NULL,
         .vec_dot = ggml_vec_dot_iq2_xxs_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
@@ -791,7 +791,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_iq2_xs,
         .from_float = NULL,
-        .from_float_reference = NULL,
+        .from_float_ref = NULL,
         .vec_dot = ggml_vec_dot_iq2_xs_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
@@ -803,7 +803,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_iq3_xxs,
         .from_float = quantize_row_iq3_xxs,
-        .from_float_reference = (ggml_from_float_t)quantize_row_iq3_xxs_reference,
+        .from_float_ref = (ggml_from_float_t)quantize_row_iq3_xxs_ref,
         .vec_dot = ggml_vec_dot_iq3_xxs_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
@@ -815,7 +815,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_iq3_s,
         .from_float = quantize_row_iq3_s,
-        .from_float_reference = (ggml_from_float_t)quantize_row_iq3_s_reference,
+        .from_float_ref = (ggml_from_float_t)quantize_row_iq3_s_ref,
         .vec_dot = ggml_vec_dot_iq3_s_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
@@ -827,7 +827,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_iq2_s,
         .from_float = quantize_row_iq2_s,
-        .from_float_reference = (ggml_from_float_t)quantize_row_iq2_s_reference,
+        .from_float_ref = (ggml_from_float_t)quantize_row_iq2_s_ref,
         .vec_dot = ggml_vec_dot_iq2_s_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
@@ -839,7 +839,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_iq1_s,
         .from_float = NULL,
-        .from_float_reference = NULL,
+        .from_float_ref = NULL,
         .vec_dot = ggml_vec_dot_iq1_s_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
@@ -851,7 +851,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_iq1_m,
         .from_float = NULL,
-        .from_float_reference = NULL,
+        .from_float_ref = NULL,
         .vec_dot = ggml_vec_dot_iq1_m_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
@@ -863,7 +863,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_iq4_nl,
         .from_float = quantize_row_iq4_nl,
-        .from_float_reference = (ggml_from_float_t)quantize_row_iq4_nl_reference,
+        .from_float_ref = (ggml_from_float_t)quantize_row_iq4_nl_ref,
         .vec_dot = ggml_vec_dot_iq4_nl_q8_0,
         .vec_dot_type = GGML_TYPE_Q8_0,
         .nrows = 1,
@@ -875,7 +875,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = true,
         .to_float = (ggml_to_float_t) dequantize_row_iq4_xs,
         .from_float = quantize_row_iq4_xs,
-        .from_float_reference = (ggml_from_float_t)quantize_row_iq4_xs_reference,
+        .from_float_ref = (ggml_from_float_t)quantize_row_iq4_xs_ref,
         .vec_dot = ggml_vec_dot_iq4_xs_q8_K,
         .vec_dot_type = GGML_TYPE_Q8_K,
         .nrows = 1,
@@ -894,7 +894,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .is_quantized = false,
         .to_float = (ggml_to_float_t) ggml_bf16_to_fp32_row,
         .from_float = (ggml_from_float_t) ggml_fp32_to_bf16_row,
-        .from_float_reference = (ggml_from_float_t) ggml_fp32_to_bf16_row,
+        .from_float_ref = (ggml_from_float_t) ggml_fp32_to_bf16_row,
         .vec_dot = (ggml_vec_dot_t) ggml_vec_dot_bf16,
         .vec_dot_type = GGML_TYPE_BF16,
         .nrows = 1,
|
@ -902,48 +902,48 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||||
[GGML_TYPE_Q4_0_4_4] = {
|
[GGML_TYPE_Q4_0_4_4] = {
|
||||||
.type_name = "q4_0_4x4",
|
.type_name = "q4_0_4x4",
|
||||||
.blck_size = QK4_0,
|
.blck_size = QK4_0,
|
||||||
|
.blck_size_interleave = 4,
|
||||||
.type_size = sizeof(block_q4_0),
|
.type_size = sizeof(block_q4_0),
|
||||||
.is_quantized = true,
|
.is_quantized = true,
|
||||||
.to_float = NULL,
|
.to_float = NULL,
|
||||||
.from_float = NULL,
|
.from_float = NULL,
|
||||||
.from_float_reference = NULL,
|
.from_float_ref = NULL,
|
||||||
.vec_dot = NULL,
|
.vec_dot = NULL,
|
||||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||||
.nrows = 1,
|
.nrows = 1,
|
||||||
.ncols = 4,
|
.ncols = 4,
|
||||||
.interleave_blcksize = 4,
|
|
||||||
.gemv = ggml_gemv_q4_0_4x4_q8_0,
|
.gemv = ggml_gemv_q4_0_4x4_q8_0,
|
||||||
.gemm = ggml_gemm_q4_0_4x4_q8_0,
|
.gemm = ggml_gemm_q4_0_4x4_q8_0,
|
||||||
},
|
},
|
||||||
[GGML_TYPE_Q4_0_4_8] = {
|
[GGML_TYPE_Q4_0_4_8] = {
|
||||||
.type_name = "q4_0_4x8",
|
.type_name = "q4_0_4x8",
|
||||||
.blck_size = QK4_0,
|
.blck_size = QK4_0,
|
||||||
|
.blck_size_interleave = 8,
|
||||||
.type_size = sizeof(block_q4_0),
|
.type_size = sizeof(block_q4_0),
|
||||||
.is_quantized = true,
|
.is_quantized = true,
|
||||||
.to_float = NULL,
|
.to_float = NULL,
|
||||||
.from_float = NULL,
|
.from_float = NULL,
|
||||||
.from_float_reference = NULL,
|
.from_float_ref = NULL,
|
||||||
.vec_dot = NULL,
|
.vec_dot = NULL,
|
||||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||||
.nrows = 1,
|
.nrows = 1,
|
||||||
.ncols = 4,
|
.ncols = 4,
|
||||||
.interleave_blcksize = 8,
|
|
||||||
.gemv = ggml_gemv_q4_0_4x8_q8_0,
|
.gemv = ggml_gemv_q4_0_4x8_q8_0,
|
||||||
.gemm = ggml_gemm_q4_0_4x8_q8_0,
|
.gemm = ggml_gemm_q4_0_4x8_q8_0,
|
||||||
},
|
},
|
||||||
[GGML_TYPE_Q4_0_8_8] = {
|
[GGML_TYPE_Q4_0_8_8] = {
|
||||||
.type_name = "q4_0_8x8",
|
.type_name = "q4_0_8x8",
|
||||||
.blck_size = QK4_0,
|
.blck_size = QK4_0,
|
||||||
|
.blck_size_interleave = 8,
|
||||||
.type_size = sizeof(block_q4_0),
|
.type_size = sizeof(block_q4_0),
|
||||||
.is_quantized = true,
|
.is_quantized = true,
|
||||||
.to_float = NULL,
|
.to_float = NULL,
|
||||||
.from_float = NULL,
|
.from_float = NULL,
|
||||||
.from_float_reference = NULL,
|
.from_float_ref = NULL,
|
||||||
.vec_dot = NULL,
|
.vec_dot = NULL,
|
||||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||||
.nrows = 1,
|
.nrows = 1,
|
||||||
.ncols = 8,
|
.ncols = 8,
|
||||||
.interleave_blcksize = 8,
|
|
||||||
.gemv = ggml_gemv_q4_0_8x8_q8_0,
|
.gemv = ggml_gemv_q4_0_8x8_q8_0,
|
||||||
.gemm = ggml_gemm_q4_0_8x8_q8_0,
|
.gemm = ggml_gemm_q4_0_8x8_q8_0,
|
||||||
}
|
}
|
||||||
|
@@ -3135,7 +3135,7 @@ size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) {
     return GGML_PAD(ggml_nbytes(tensor), GGML_MEM_ALIGN);
 }

-GGML_CALL int ggml_blck_size(enum ggml_type type) {
+GGML_CALL int64_t ggml_blck_size(enum ggml_type type) {
     return type_traits[type].blck_size;
 }

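The widened return type mainly matters for 64-bit row bookkeeping. A small illustrative helper (assuming only the public ggml.h API; the helper name is made up):

    #include "ggml.h"

    // number of quantization blocks needed for one row of ne0 elements
    static int64_t blocks_per_row(enum ggml_type type, int64_t ne0) {
        const int64_t blck = ggml_blck_size(type);   // e.g. 32 for GGML_TYPE_Q4_0
        return (ne0 + blck - 1) / blck;              // no narrowing back to int
    }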
@ -12238,15 +12238,14 @@ static void ggml_compute_forward_mul_mat(
|
||||||
|
|
||||||
const enum ggml_type type = src0->type;
|
const enum ggml_type type = src0->type;
|
||||||
|
|
||||||
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
|
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
|
||||||
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
|
ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float;
|
||||||
int64_t const vec_dot_num_rows = type_traits[type].nrows;
|
ggml_from_float_to_mat_t const from_float_to_mat = type_traits[vec_dot_type].from_float_to_mat;
|
||||||
int64_t const matmul_num_cols = type_traits[type].ncols;
|
int64_t const vec_dot_num_rows = type_traits[type].nrows;
|
||||||
int64_t const interleave_blcksize = type_traits[type].interleave_blcksize;
|
int64_t const matmul_num_cols = type_traits[type].ncols;
|
||||||
ggml_from_float_to_mat_t const from_float_to_mat
|
int64_t const blck_size_interleave = type_traits[type].blck_size_interleave;
|
||||||
= type_traits[vec_dot_type].from_float_to_mat;
|
ggml_gemv_t const gemv = type_traits[type].gemv;
|
||||||
ggml_gemv_t const gemv = type_traits[type].gemv;
|
ggml_gemm_t const gemm = type_traits[type].gemm;
|
||||||
ggml_gemm_t const gemm = type_traits[type].gemm;
|
|
||||||
|
|
||||||
GGML_ASSERT(ne0 == ne01);
|
GGML_ASSERT(ne0 == ne01);
|
||||||
GGML_ASSERT(ne1 == ne11);
|
GGML_ASSERT(ne1 == ne11);
|
||||||
@@ -12318,14 +12317,14 @@ UseGgmlGemm1:;
             for (int64_t i11 = ith * 4; i11 < ne11 - ne11 % 4; i11 += nth * 4) {
                 from_float_to_mat((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
                                   (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
-                                  4, ne10, interleave_blcksize);
+                                  4, ne10, blck_size_interleave);
             }
             i11_processed = ne11 - ne11 % 4;
         }
         for (int64_t i11 = i11_processed + ith; i11 < ne11; i11 += nth) {
-            from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
+            from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
                        (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
                        ne10);
         }
     }
 }
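
Worked example for the two loops above, assuming a single thread (ith = 0, nth = 1) and ne11 = 10: the first loop quantizes src1 rows 0–7 into wdata in two 4-row batches through from_float_to_mat, i11_processed is then set to ne11 - ne11 % 4 = 8, and the remainder loop converts rows 8 and 9 one at a time through from_float.
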
@@ -12409,7 +12408,7 @@ UseGgmlGemm2:;
     int64_t src0_start = (ith * ne01) / nth;
     int64_t src0_end = ((ith + 1) * ne01) / nth;
     src0_start = (src0_start % matmul_num_cols) ? src0_start + matmul_num_cols - (src0_start % matmul_num_cols): src0_start;
     src0_end = (src0_end % matmul_num_cols) ? src0_end + matmul_num_cols - (src0_end % matmul_num_cols): src0_end;
     if (src0_start >= src0_end) return;
 
     // If there are more than three rows in src1, use gemm; otherwise, use gemv.
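
The comment above states the dispatch rule for the interleaved types: one batched gemm call once src1 has more than three rows, otherwise one gemv call per row. A hedged sketch of that rule (the names, signatures and main() driver are illustrative only; just the row-count threshold mirrors the diff):

    #include <stdint.h>
    #include <stdio.h>

    static void gemv_sketch(int64_t row) {
        printf("gemv on row %lld\n", (long long) row);
    }
    static void gemm_sketch(int64_t first, int64_t nrows) {
        printf("gemm on rows %lld..%lld\n", (long long) first, (long long) (first + nrows - 1));
    }

    static void dispatch(int64_t ne11 /* rows in src1 */) {
        if (ne11 > 3) {
            gemm_sketch(0, ne11);       /* one batched call */
        } else {
            for (int64_t i = 0; i < ne11; ++i) {
                gemv_sketch(i);         /* one call per src1 row */
            }
        }
    }

    int main(void) {
        dispatch(2);  /* -> two gemv calls */
        dispatch(8);  /* -> one gemm call  */
        return 0;
    }
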
@@ -12467,11 +12466,11 @@ static void ggml_compute_forward_mul_mat_id(
 
     const bool src1_cont = ggml_is_contiguous(src1);
 
     ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
     enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
-    ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
+    ggml_from_float_t const from_float = type_traits[vec_dot_type].from_float;
     int64_t const matmul_num_cols = type_traits[type].ncols;
     ggml_gemv_t const gemv = type_traits[type].gemv;
 
     // we don't support permuted src0 or src1
     GGML_ASSERT(nb00 == ggml_type_size(type));
@@ -12512,9 +12511,9 @@ static void ggml_compute_forward_mul_mat_id(
         for (int64_t i13 = 0; i13 < ne13; ++i13) {
             for (int64_t i12 = 0; i12 < ne12; ++i12) {
                 for (int64_t i11 = ith; i11 < ne11; i11 += nth) {
-                    from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
+                    from_float((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11),
                                (void *) (wdata + i13*nbw3 + i12*nbw2 + i11*nbw1),
                                ne10);
                 }
             }
         }
@@ -21166,8 +21165,8 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
                 (int64_t) info->ne[3];
 
             if (ne % ggml_blck_size(info->type) != 0) {
-                fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%d)\n",
-                        __func__, info->name.data, (int)info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type));
+                fprintf(stderr, "%s: tensor '%s' of type %d (%s) number of elements (%" PRId64 ") is not a multiple of block size (%" PRId64 ")\n",
+                        __func__, info->name.data, (int) info->type, ggml_type_name(info->type), ne, ggml_blck_size(info->type));
                 fclose(file);
                 gguf_free(ctx);
                 return NULL;

@@ -5955,13 +5955,6 @@ static bool llm_load_tensors(
 
     auto & hparams = model.hparams;
 
-#ifdef GGML_USE_SYCL
-    // disable MoE with SYCL until mul_mat_id is updated
-    if (hparams.n_expert > 0) {
-        n_gpu_layers = 0;
-    }
-#endif
-
     model.split_mode = split_mode;
     model.main_gpu = main_gpu;
     model.n_gpu_layers = n_gpu_layers;
@@ -21500,7 +21493,7 @@ int32_t llama_token_to_piece(const struct llama_model * model, llama_token token
             size--;
         }
         if (length < (int32_t)size) {
-            return (int32_t) -size;
+            return -(int32_t) size;
         }
         memcpy(buf, token, size);
         return (int32_t) size;
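
The swapped operator order in return -(int32_t) size; is subtle: assuming size has an unsigned type such as size_t (its declaration sits outside this hunk), -size first wraps modulo 2^N and the narrowing cast of that large value to int32_t is implementation-defined, whereas casting first and negating afterwards stays in ordinary signed arithmetic. A small standalone illustration:

    #include <stddef.h>
    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        size_t size = 5;
        int32_t a = (int32_t) -size;  /* negate unsigned, then narrow: implementation-defined */
        int32_t b = -(int32_t) size;  /* narrow, then negate: well-defined for small sizes */
        /* on common two's-complement targets both print -5; only the second
         * form avoids the implementation-defined conversion */
        printf("(int32_t) -size = %d\n", (int) a);
        printf("-(int32_t) size = %d\n", (int) b);
        return 0;
    }
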